Spaces:
Runtime error
Runtime error
Commit
·
63775f2
1
Parent(s):
c2c01a0
add necessary file
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- anonymous_demo/__init__.py +5 -0
- anonymous_demo/core/__init__.py +0 -0
- anonymous_demo/core/tad/__init__.py +0 -0
- anonymous_demo/core/tad/classic/__bert__/README.MD +3 -0
- anonymous_demo/core/tad/classic/__bert__/__init__.py +1 -0
- anonymous_demo/core/tad/classic/__bert__/dataset_utils/__init__.py +0 -0
- anonymous_demo/core/tad/classic/__bert__/dataset_utils/data_utils_for_inference.py +121 -0
- anonymous_demo/core/tad/classic/__bert__/models/__init__.py +1 -0
- anonymous_demo/core/tad/classic/__bert__/models/tad_bert.py +46 -0
- anonymous_demo/core/tad/classic/__init__.py +0 -0
- anonymous_demo/core/tad/models/__init__.py +9 -0
- anonymous_demo/core/tad/prediction/__init__.py +0 -0
- anonymous_demo/core/tad/prediction/tad_classifier.py +518 -0
- anonymous_demo/functional/__init__.py +3 -0
- anonymous_demo/functional/checkpoint/__init__.py +1 -0
- anonymous_demo/functional/checkpoint/checkpoint_manager.py +19 -0
- anonymous_demo/functional/config/__init__.py +1 -0
- anonymous_demo/functional/config/config_manager.py +64 -0
- anonymous_demo/functional/config/tad_config_manager.py +229 -0
- anonymous_demo/functional/dataset/__init__.py +1 -0
- anonymous_demo/functional/dataset/dataset_manager.py +45 -0
- anonymous_demo/network/__init__.py +0 -0
- anonymous_demo/network/lcf_pooler.py +28 -0
- anonymous_demo/network/lsa.py +73 -0
- anonymous_demo/network/sa_encoder.py +199 -0
- anonymous_demo/utils/__init__.py +0 -0
- anonymous_demo/utils/demo_utils.py +247 -0
- anonymous_demo/utils/logger.py +38 -0
- checkpoints.zip +3 -0
- text_defense/201.SST2/stsa.binary.dev.dat +0 -0
- text_defense/201.SST2/stsa.binary.test.dat +0 -0
- text_defense/201.SST2/stsa.binary.train.dat +0 -0
- text_defense/202.IMDB10K/imdb10k.test.dat +0 -0
- text_defense/202.IMDB10K/imdb10k.train.dat +0 -0
- text_defense/202.IMDB10K/imdb10k.valid.dat +0 -0
- text_defense/204.AGNews10K/AGNews10K.test.dat +0 -0
- text_defense/204.AGNews10K/AGNews10K.train.dat +0 -0
- text_defense/204.AGNews10K/AGNews10K.valid.dat +0 -0
- text_defense/206.Amazon_Review_Polarity10K/amazon.test.dat +0 -0
- text_defense/206.Amazon_Review_Polarity10K/amazon.train.dat +0 -0
- textattack/__init__.py +39 -0
- textattack/__main__.py +6 -0
- textattack/attack.py +492 -0
- textattack/attack_args.py +763 -0
- textattack/attack_recipes/__init__.py +43 -0
- textattack/attack_recipes/a2t_yoo_2021.py +74 -0
- textattack/attack_recipes/attack_recipe.py +30 -0
- textattack/attack_recipes/bae_garg_2019.py +123 -0
- textattack/attack_recipes/bert_attack_li_2020.py +95 -0
- textattack/attack_recipes/checklist_ribeiro_2020.py +53 -0
anonymous_demo/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__version__ = "1.0.0"
|
| 2 |
+
|
| 3 |
+
__name__ = "anonymous_demo"
|
| 4 |
+
|
| 5 |
+
from anonymous_demo.functional import TADCheckpointManager
|
anonymous_demo/core/__init__.py
ADDED
|
File without changes
|
anonymous_demo/core/tad/__init__.py
ADDED
|
File without changes
|
anonymous_demo/core/tad/classic/__bert__/README.MD
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## This is a simple migration from ABSA-PyTorch, which is released under the MIT license
|
| 2 |
+
|
| 3 |
+
Project Address: https://github.com/songyouwei/ABSA-PyTorch
|
anonymous_demo/core/tad/classic/__bert__/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .models import *
|
anonymous_demo/core/tad/classic/__bert__/dataset_utils/__init__.py
ADDED
|
File without changes
|
anonymous_demo/core/tad/classic/__bert__/dataset_utils/data_utils_for_inference.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tqdm
|
| 2 |
+
from findfile import find_cwd_dir
|
| 3 |
+
from torch.utils.data import Dataset
|
| 4 |
+
from transformers import AutoTokenizer
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Tokenizer4Pretraining:
    """Wrap a Hugging Face tokenizer and encode text to a fixed sequence length."""

    def __init__(self, max_seq_len, opt, **kwargs):
        """Load the tokenizer named by ``opt.pretrained_bert``.

        Args:
            max_seq_len: Length every encoded sequence is padded/truncated to.
            opt: Configuration object; only ``opt.pretrained_bert`` is read here.
            **kwargs: ``offline=True`` loads the tokenizer from a directory found
                via ``find_cwd_dir`` (searched by the last path component of the
                model name) instead of the model name itself.
        """
        use_local_copy = kwargs.pop("offline", False)
        if use_local_copy:
            source = find_cwd_dir(opt.pretrained_bert.split("/")[-1])
        else:
            source = opt.pretrained_bert
        # Lower-casing is enabled only for model names containing "uncased".
        lower_case = "uncased" in opt.pretrained_bert
        self.tokenizer = AutoTokenizer.from_pretrained(source, do_lower_case=lower_case)
        self.max_seq_len = max_seq_len

    def text_to_sequence(self, text, reverse=False, padding="post", truncating="post"):
        """Encode ``text`` with truncation and padding to ``self.max_seq_len``.

        ``reverse``, ``padding`` and ``truncating`` are accepted but not used by
        this implementation. Returns whatever ``tokenizer.encode`` returns when
        called with ``return_tensors="pt"``.
        """
        encode_options = {
            "truncation": True,
            "padding": "max_length",
            "max_length": self.max_seq_len,
            "return_tensors": "pt",
        }
        return self.tokenizer.encode(text, **encode_options)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class BERTTADDataset(Dataset):
    """In-memory inference dataset for the BERT-based TAD classifier.

    Each stored item is a dict with the keys ``text_bert_indices``, ``text_raw``,
    ``label``, ``adv_train_label`` and ``is_adv``. Reference labels are optional
    and are given inline after a ``!ref!`` marker as ``label[, is_adv[,
    adv_train_label]]``; missing references are stored as ``-100``.
    """

    # Marker separating the text from its optional reference labels.
    _REF_MARKER = "!ref!"

    def __init__(self, tokenizer, opt):
        """Store the tokenizer and config; the dataset starts empty."""
        self.bert_baseline_input_colses = {"bert": ["text_bert_indices"]}

        self.tokenizer = tokenizer
        self.opt = opt
        self.all_data = []

    def parse_sample(self, text):
        """Return the list of samples contained in ``text`` (always one)."""
        return [text]

    def prepare_infer_sample(self, text: str, ignore_error):
        """Replace the dataset content with the single sample ``text``."""
        self.process_data(self.parse_sample(text), ignore_error=ignore_error)

    def prepare_infer_dataset(self, infer_file, ignore_error):
        """Replace the dataset content with the non-blank lines of one or more files.

        Args:
            infer_file: A path, or an iterable of paths, to UTF-8 text files with
                one sample per line.
            ignore_error: Forwarded to :meth:`process_data`.
        """
        paths = [infer_file] if isinstance(infer_file, str) else list(infer_file)
        samples = []
        for path in paths:
            with open(path, mode="r", encoding="utf-8") as fin:
                for line in fin:
                    if line.strip():
                        samples.extend(self.parse_sample(line))
        self.process_data(samples, ignore_error=ignore_error)

    @staticmethod
    def _parse_reference_labels(labels):
        """Parse ``"label[, is_adv[, adv_train_label]]"`` into three ints.

        Fields that are absent default to ``-100``; if more than three fields
        are given, all three values are ``-100``. Raises ``ValueError`` when a
        field is not an integer literal.
        """
        fields = [field.strip() for field in labels.strip().split(",")]
        label, is_adv, adv_train_label = "-100", "-100", "-100"
        if len(fields) == 3:
            label, is_adv, adv_train_label = fields
        elif len(fields) == 2:
            label, is_adv = fields
        elif len(fields) == 1:
            label = fields[0]
        return int(label), int(is_adv), int(adv_train_label)

    def process_data(self, samples, ignore_error=True):
        """Tokenize ``samples`` and store them as the dataset content.

        Args:
            samples: Iterable of raw text lines, optionally with ``!ref!`` labels.
            ignore_error: When true, a sample that fails to process is reported
                with ``print`` and skipped; otherwise the exception propagates.

        Returns:
            The list of processed items (also stored in ``self.all_data``).
        """
        all_data = []
        if len(samples) > 100:
            # Only show a progress bar for larger inputs.
            it = tqdm.tqdm(
                samples, postfix="preparing text classification inference dataloader..."
            )
        else:
            it = samples
        for text in it:
            try:
                # handle for empty lines in inference datasets
                if text is None or "" == text.strip():
                    raise RuntimeError("Invalid Input!")

                if self._REF_MARKER in text:
                    text, _, labels = text.strip().partition(self._REF_MARKER)
                    text = text.strip()
                    label, is_adv, adv_train_label = self._parse_reference_labels(labels)
                else:
                    text = text.strip()
                    label = adv_train_label = is_adv = -100

                text_indices = self.tokenizer.text_to_sequence("{}".format(text))

                all_data.append(
                    {
                        # The tokenizer output carries a leading batch axis of
                        # size one; only its first row is stored.
                        "text_bert_indices": text_indices[0],
                        "text_raw": text,
                        "label": label,
                        "adv_train_label": adv_train_label,
                        "is_adv": is_adv,
                    }
                )

            except Exception as e:
                if ignore_error:
                    print("Ignore error while processing:", text)
                else:
                    raise e

        self.all_data = all_data
        return self.all_data

    def __getitem__(self, index):
        return self.all_data[index]

    def __len__(self):
        return len(self.all_data)
|
anonymous_demo/core/tad/classic/__bert__/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .tad_bert import TADBERT
|
anonymous_demo/core/tad/classic/__bert__/models/tad_bert.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from transformers.models.bert.modeling_bert import BertPooler
|
| 4 |
+
|
| 5 |
+
from anonymous_demo.network.sa_encoder import Encoder
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TADBERT(nn.Module):
    """BERT backbone with three linear heads over one pooled representation.

    ``forward`` returns logits from ``dense1`` (``sent_logits``), ``dense2``
    (``advdet_logits``) and ``dense3`` (``adv_tr_logits``), together with the
    backbone's last hidden state and a normalized score tensor.
    """

    inputs = ["text_bert_indices"]

    def __init__(self, bert, opt):
        """Build the heads.

        Args:
            bert: Backbone module; it must expose ``config`` and return a
                mapping with a ``"last_hidden_state"`` entry when called.
            opt: Configuration providing ``hidden_dim``, ``class_dim`` and
                ``adv_det_dim``.
        """
        super(TADBERT, self).__init__()
        self.opt = opt
        self.bert = bert
        self.pooler = BertPooler(bert.config)
        self.dense1 = nn.Linear(self.opt.hidden_dim, self.opt.class_dim)
        self.dense2 = nn.Linear(self.opt.hidden_dim, self.opt.adv_det_dim)
        self.dense3 = nn.Linear(self.opt.hidden_dim, self.opt.class_dim)

        # These encoders are not used by forward(). They stay registered so the
        # module's parameter set (and therefore its state_dict keys) is unchanged.
        self.encoder1 = Encoder(self.bert.config, opt=opt)
        self.encoder2 = Encoder(self.bert.config, opt=opt)
        self.encoder3 = Encoder(self.bert.config, opt=opt)

    def forward(self, inputs):
        """Run the backbone and all three heads.

        Args:
            inputs: Sequence whose first element is the token-id tensor.

        Returns:
            Dict with keys ``sent_logits``, ``advdet_logits``, ``adv_tr_logits``,
            ``last_hidden_state`` and ``att_score``.
        """
        text_raw_indices = inputs[0]
        last_hidden_state = self.bert(text_raw_indices)["last_hidden_state"]

        # All heads share the same pooled vector, so pool once instead of
        # once per head.
        pooled = self.pooler(last_hidden_state)
        sent_logits = self.dense1(pooled)
        advdet_logits = self.dense2(pooled)
        adv_tr_logits = self.dense3(pooled)

        # NOTE(review): the sum drops dim 1 while the min keeps it, so the two
        # operands differ in rank and the subtraction broadcasts — confirm the
        # resulting shape is what consumers of "att_score" expect.
        att_score = torch.nn.functional.normalize(
            last_hidden_state.abs().sum(dim=1, keepdim=False)
            - last_hidden_state.abs().min(dim=1, keepdim=True)[0],
            p=1,
            dim=1,
        )

        outputs = {
            "sent_logits": sent_logits,
            "advdet_logits": advdet_logits,
            "adv_tr_logits": adv_tr_logits,
            "last_hidden_state": last_hidden_state,
            "att_score": att_score,
        }
        return outputs
|
anonymous_demo/core/tad/classic/__init__.py
ADDED
|
File without changes
|
anonymous_demo/core/tad/models/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import anonymous_demo.core.tad.classic.__bert__.models
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class BERTTADModelList(list):
    """List of the available BERT-based TAD model classes.

    Each model is also exposed as a class attribute so callers can check for a
    model by name with ``hasattr``.
    """

    TADBERT = anonymous_demo.core.tad.classic.__bert__.TADBERT

    def __init__(self):
        # Populate the list with every model class this registry knows about.
        super().__init__([self.TADBERT])
|
anonymous_demo/core/tad/prediction/__init__.py
ADDED
|
File without changes
|
anonymous_demo/core/tad/prediction/tad_classifier.py
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import pickle
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import tqdm
|
| 8 |
+
from findfile import find_file, find_cwd_dir
|
| 9 |
+
from termcolor import colored
|
| 10 |
+
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from transformers import (
|
| 13 |
+
AutoTokenizer,
|
| 14 |
+
AutoModel,
|
| 15 |
+
AutoConfig,
|
| 16 |
+
DebertaV2ForMaskedLM,
|
| 17 |
+
RobertaForMaskedLM,
|
| 18 |
+
BertForMaskedLM,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
from ....functional.dataset.dataset_manager import detect_infer_dataset
|
| 22 |
+
|
| 23 |
+
from ..models import BERTTADModelList
|
| 24 |
+
from ..classic.__bert__.dataset_utils.data_utils_for_inference import (
|
| 25 |
+
BERTTADDataset,
|
| 26 |
+
Tokenizer4Pretraining,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
from ....utils.demo_utils import (
|
| 30 |
+
print_args,
|
| 31 |
+
TransformerConnectionError,
|
| 32 |
+
get_device,
|
| 33 |
+
build_embedding_matrix,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def init_attacker(tad_classifier, defense):
    """Build a TextAttack attacker around ``tad_classifier``.

    Args:
        tad_classifier: Object exposing ``infer(text, print_result=False, ...)``
            that returns a mapping with a ``"probs"`` entry.
        defense: Recipe key; one of ``bae``, ``pwws``, ``textfooler``, ``pso``,
            ``iga``, ``ga`` or ``wordbugger``.

    Returns:
        An object whose ``attacker`` attribute is a ``textattack.Attacker``.
        If anything fails (TextAttack missing, unknown key, build error) the
        error is printed and ``None`` is returned.
    """
    try:
        # Imported lazily so the rest of the module works without TextAttack.
        from textattack import Attacker
        from textattack.attack_recipes import (
            BAEGarg2019,
            PWWSRen2019,
            TextFoolerJin2019,
            PSOZang2020,
            IGAWang2019,
            GeneticAlgorithmAlzantot2018,
            DeepWordBugGao2018,
        )
        from textattack.datasets import Dataset
        from textattack.models.wrappers import HuggingFaceModelWrapper

        class DemoModelWrapper(HuggingFaceModelWrapper):
            """Adapter exposing the classifier's probabilities to TextAttack."""

            def __init__(self, model):
                self.model = model  # pipeline = pipeline

            def __call__(self, text_inputs, **kwargs):
                # One inference call per input text; collect the class probabilities.
                return [
                    self.model.infer(single_text, print_result=False, **kwargs)["probs"]
                    for single_text in text_inputs
                ]

        class SentAttacker:
            """Holds an ``Attacker`` built from the chosen recipe."""

            def __init__(self, model, recipe_class=BAEGarg2019):
                recipe = recipe_class.build(DemoModelWrapper(model))
                # The attacker is constructed with a single placeholder example.
                placeholder_dataset = Dataset([("", 0)])
                self.attacker = Attacker(recipe, placeholder_dataset)

        recipes_by_name = {
            "bae": BAEGarg2019,
            "pwws": PWWSRen2019,
            "textfooler": TextFoolerJin2019,
            "pso": PSOZang2020,
            "iga": IGAWang2019,
            "ga": GeneticAlgorithmAlzantot2018,
            "wordbugger": DeepWordBugGao2018,
        }
        return SentAttacker(tad_classifier, recipes_by_name[defense])
    except Exception as e:
        print("Original error:", e)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_mlm_and_tokenizer(text_classifier, config):
    """Build a masked-language model that reuses the classifier's backbone.

    Args:
        text_classifier: Either a ``TADTextClassifier`` (its ``model.bert`` is
            used) or an object that exposes ``bert`` directly.
        config: Configuration providing ``pretrained_bert``.

    Returns:
        ``(MLM, tokenizer)``. The MLM class is chosen from the model name
        ("deberta-v3", then "roberta", otherwise BERT), and its backbone
        attribute is replaced with the classifier's base model.
    """
    if isinstance(text_classifier, TADTextClassifier):
        backbone_owner = text_classifier.model
    else:
        backbone_owner = text_classifier
    base_model = backbone_owner.bert.base_model

    pretrained_config = AutoConfig.from_pretrained(config.pretrained_bert)

    # Order matters: the first matching name fragment wins.
    if "deberta-v3" in config.pretrained_bert:
        mlm_class, backbone_attr = DebertaV2ForMaskedLM, "deberta"
    elif "roberta" in config.pretrained_bert:
        mlm_class, backbone_attr = RobertaForMaskedLM, "roberta"
    else:
        mlm_class, backbone_attr = BertForMaskedLM, "bert"

    MLM = mlm_class(pretrained_config)
    setattr(MLM, backbone_attr, base_model)
    return MLM, AutoTokenizer.from_pretrained(config.pretrained_bert)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class TADTextClassifier:
|
| 110 |
+
def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):
|
| 111 |
+
"""
|
| 112 |
+
from_train_model: load inference model from trained model
|
| 113 |
+
"""
|
| 114 |
+
self.cal_perplexity = cal_perplexity
|
| 115 |
+
# load from a training
|
| 116 |
+
if not isinstance(model_arg, str):
|
| 117 |
+
print("Load text classifier from training")
|
| 118 |
+
self.model = model_arg[0]
|
| 119 |
+
self.opt = model_arg[1]
|
| 120 |
+
self.tokenizer = model_arg[2]
|
| 121 |
+
else:
|
| 122 |
+
try:
|
| 123 |
+
if "fine-tuned" in model_arg:
|
| 124 |
+
raise ValueError(
|
| 125 |
+
"Do not support to directly load a fine-tuned model, please load a .state_dict or .model instead!"
|
| 126 |
+
)
|
| 127 |
+
print("Load text classifier from", model_arg)
|
| 128 |
+
state_dict_path = find_file(
|
| 129 |
+
model_arg, key=".state_dict", exclude_key=["__MACOSX"]
|
| 130 |
+
)
|
| 131 |
+
model_path = find_file(
|
| 132 |
+
model_arg, key=".model", exclude_key=["__MACOSX"]
|
| 133 |
+
)
|
| 134 |
+
tokenizer_path = find_file(
|
| 135 |
+
model_arg, key=".tokenizer", exclude_key=["__MACOSX"]
|
| 136 |
+
)
|
| 137 |
+
config_path = find_file(
|
| 138 |
+
model_arg, key=".config", exclude_key=["__MACOSX"]
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
print("config: {}".format(config_path))
|
| 142 |
+
print("state_dict: {}".format(state_dict_path))
|
| 143 |
+
print("model: {}".format(model_path))
|
| 144 |
+
print("tokenizer: {}".format(tokenizer_path))
|
| 145 |
+
|
| 146 |
+
with open(config_path, mode="rb") as f:
|
| 147 |
+
self.opt = pickle.load(f)
|
| 148 |
+
self.opt.device = get_device(kwargs.pop("auto_device", True))[0]
|
| 149 |
+
|
| 150 |
+
if state_dict_path or model_path:
|
| 151 |
+
if hasattr(BERTTADModelList, self.opt.model.__name__):
|
| 152 |
+
if state_dict_path:
|
| 153 |
+
if kwargs.pop("offline", False):
|
| 154 |
+
self.bert = AutoModel.from_pretrained(
|
| 155 |
+
find_cwd_dir(
|
| 156 |
+
self.opt.pretrained_bert.split("/")[-1]
|
| 157 |
+
)
|
| 158 |
+
)
|
| 159 |
+
else:
|
| 160 |
+
self.bert = AutoModel.from_pretrained(
|
| 161 |
+
self.opt.pretrained_bert
|
| 162 |
+
)
|
| 163 |
+
self.model = self.opt.model(self.bert, self.opt)
|
| 164 |
+
self.model.load_state_dict(
|
| 165 |
+
torch.load(state_dict_path, map_location="cpu")
|
| 166 |
+
)
|
| 167 |
+
elif model_path:
|
| 168 |
+
self.model = torch.load(model_path, map_location="cpu")
|
| 169 |
+
|
| 170 |
+
try:
|
| 171 |
+
self.tokenizer = Tokenizer4Pretraining(
|
| 172 |
+
max_seq_len=self.opt.max_seq_len, opt=self.opt, **kwargs
|
| 173 |
+
)
|
| 174 |
+
except ValueError:
|
| 175 |
+
if tokenizer_path:
|
| 176 |
+
with open(tokenizer_path, mode="rb") as f:
|
| 177 |
+
self.tokenizer = pickle.load(f)
|
| 178 |
+
else:
|
| 179 |
+
raise TransformerConnectionError()
|
| 180 |
+
|
| 181 |
+
except Exception as e:
|
| 182 |
+
raise RuntimeError(
|
| 183 |
+
"Exception: {} Fail to load the model from {}! ".format(
|
| 184 |
+
e, model_arg
|
| 185 |
+
)
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
self.infer_dataloader = None
|
| 189 |
+
self.opt.eval_batch_size = kwargs.pop("eval_batch_size", 128)
|
| 190 |
+
|
| 191 |
+
self.opt.initializer = self.opt.initializer
|
| 192 |
+
|
| 193 |
+
if self.cal_perplexity:
|
| 194 |
+
try:
|
| 195 |
+
self.MLM, self.MLM_tokenizer = get_mlm_and_tokenizer(self, self.opt)
|
| 196 |
+
except Exception as e:
|
| 197 |
+
self.MLM, self.MLM_tokenizer = None, None
|
| 198 |
+
|
| 199 |
+
self.to(self.opt.device)
|
| 200 |
+
|
| 201 |
+
def to(self, device=None):
|
| 202 |
+
self.opt.device = device
|
| 203 |
+
self.model.to(device)
|
| 204 |
+
if hasattr(self, "MLM"):
|
| 205 |
+
self.MLM.to(self.opt.device)
|
| 206 |
+
|
| 207 |
+
def cpu(self):
|
| 208 |
+
self.opt.device = "cpu"
|
| 209 |
+
self.model.to("cpu")
|
| 210 |
+
if hasattr(self, "MLM"):
|
| 211 |
+
self.MLM.to("cpu")
|
| 212 |
+
|
| 213 |
+
def cuda(self, device="cuda:0"):
|
| 214 |
+
self.opt.device = device
|
| 215 |
+
self.model.to(device)
|
| 216 |
+
if hasattr(self, "MLM"):
|
| 217 |
+
self.MLM.to(device)
|
| 218 |
+
|
| 219 |
+
def _log_write_args(self):
|
| 220 |
+
n_trainable_params, n_nontrainable_params = 0, 0
|
| 221 |
+
for p in self.model.parameters():
|
| 222 |
+
n_params = torch.prod(torch.tensor(p.shape))
|
| 223 |
+
if p.requires_grad:
|
| 224 |
+
n_trainable_params += n_params
|
| 225 |
+
else:
|
| 226 |
+
n_nontrainable_params += n_params
|
| 227 |
+
print(
|
| 228 |
+
"n_trainable_params: {0}, n_nontrainable_params: {1}".format(
|
| 229 |
+
n_trainable_params, n_nontrainable_params
|
| 230 |
+
)
|
| 231 |
+
)
|
| 232 |
+
for arg in vars(self.opt):
|
| 233 |
+
if getattr(self.opt, arg) is not None:
|
| 234 |
+
print(">>> {0}: {1}".format(arg, getattr(self.opt, arg)))
|
| 235 |
+
|
| 236 |
+
def batch_infer(
|
| 237 |
+
self,
|
| 238 |
+
target_file=None,
|
| 239 |
+
print_result=True,
|
| 240 |
+
save_result=False,
|
| 241 |
+
ignore_error=True,
|
| 242 |
+
defense: str = None,
|
| 243 |
+
):
|
| 244 |
+
save_path = os.path.join(os.getcwd(), "tad_text_classification.result.json")
|
| 245 |
+
|
| 246 |
+
target_file = detect_infer_dataset(target_file, task="text_defense")
|
| 247 |
+
if not target_file:
|
| 248 |
+
raise FileNotFoundError("Can not find inference datasets!")
|
| 249 |
+
|
| 250 |
+
if hasattr(BERTTADModelList, self.opt.model.__name__):
|
| 251 |
+
dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)
|
| 252 |
+
|
| 253 |
+
dataset.prepare_infer_dataset(target_file, ignore_error=ignore_error)
|
| 254 |
+
self.infer_dataloader = DataLoader(
|
| 255 |
+
dataset=dataset,
|
| 256 |
+
batch_size=self.opt.eval_batch_size,
|
| 257 |
+
pin_memory=True,
|
| 258 |
+
shuffle=False,
|
| 259 |
+
)
|
| 260 |
+
return self._infer(
|
| 261 |
+
save_path=save_path if save_result else None,
|
| 262 |
+
print_result=print_result,
|
| 263 |
+
defense=defense,
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
def infer(
|
| 267 |
+
self,
|
| 268 |
+
text: str = None,
|
| 269 |
+
print_result=True,
|
| 270 |
+
ignore_error=True,
|
| 271 |
+
defense: str = None,
|
| 272 |
+
):
|
| 273 |
+
if hasattr(BERTTADModelList, self.opt.model.__name__):
|
| 274 |
+
dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)
|
| 275 |
+
|
| 276 |
+
if text:
|
| 277 |
+
dataset.prepare_infer_sample(text, ignore_error=ignore_error)
|
| 278 |
+
else:
|
| 279 |
+
raise RuntimeError("Please specify your datasets path!")
|
| 280 |
+
self.infer_dataloader = DataLoader(
|
| 281 |
+
dataset=dataset, batch_size=self.opt.eval_batch_size, shuffle=False
|
| 282 |
+
)
|
| 283 |
+
return self._infer(print_result=print_result, defense=defense)[0]
|
| 284 |
+
|
| 285 |
+
def _infer(self, save_path=None, print_result=True, defense=None):
|
| 286 |
+
_params = filter(lambda p: p.requires_grad, self.model.parameters())
|
| 287 |
+
|
| 288 |
+
correct = {True: "Correct", False: "Wrong"}
|
| 289 |
+
results = []
|
| 290 |
+
|
| 291 |
+
with torch.no_grad():
|
| 292 |
+
self.model.eval()
|
| 293 |
+
n_correct = 0
|
| 294 |
+
n_labeled = 0
|
| 295 |
+
|
| 296 |
+
n_advdet_correct = 0
|
| 297 |
+
n_advdet_labeled = 0
|
| 298 |
+
if len(self.infer_dataloader.dataset) >= 100:
|
| 299 |
+
it = tqdm.tqdm(self.infer_dataloader, postfix="inferring...")
|
| 300 |
+
else:
|
| 301 |
+
it = self.infer_dataloader
|
| 302 |
+
for _, sample in enumerate(it):
|
| 303 |
+
inputs = [
|
| 304 |
+
sample[col].to(self.opt.device) for col in self.opt.inputs_cols
|
| 305 |
+
]
|
| 306 |
+
outputs = self.model(inputs)
|
| 307 |
+
logits, advdet_logits, adv_tr_logits = (
|
| 308 |
+
outputs["sent_logits"],
|
| 309 |
+
outputs["advdet_logits"],
|
| 310 |
+
outputs["adv_tr_logits"],
|
| 311 |
+
)
|
| 312 |
+
probs, advdet_probs, adv_tr_probs = (
|
| 313 |
+
torch.softmax(logits, dim=-1),
|
| 314 |
+
torch.softmax(advdet_logits, dim=-1),
|
| 315 |
+
torch.softmax(adv_tr_logits, dim=-1),
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
for i, (prob, advdet_prob, adv_tr_prob) in enumerate(
|
| 319 |
+
zip(probs, advdet_probs, adv_tr_probs)
|
| 320 |
+
):
|
| 321 |
+
text_raw = sample["text_raw"][i]
|
| 322 |
+
|
| 323 |
+
pred_label = int(prob.argmax(axis=-1))
|
| 324 |
+
pred_is_adv_label = int(advdet_prob.argmax(axis=-1))
|
| 325 |
+
pred_adv_tr_label = int(adv_tr_prob.argmax(axis=-1))
|
| 326 |
+
ref_label = (
|
| 327 |
+
int(sample["label"][i])
|
| 328 |
+
if int(sample["label"][i]) in self.opt.index_to_label
|
| 329 |
+
else ""
|
| 330 |
+
)
|
| 331 |
+
ref_is_adv_label = (
|
| 332 |
+
int(sample["is_adv"][i])
|
| 333 |
+
if int(sample["is_adv"][i]) in self.opt.index_to_is_adv
|
| 334 |
+
else ""
|
| 335 |
+
)
|
| 336 |
+
ref_adv_tr_label = (
|
| 337 |
+
int(sample["adv_train_label"][i])
|
| 338 |
+
if int(sample["adv_train_label"][i])
|
| 339 |
+
in self.opt.index_to_adv_train_label
|
| 340 |
+
else ""
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
if self.cal_perplexity:
|
| 344 |
+
ids = self.MLM_tokenizer(text_raw, return_tensors="pt")
|
| 345 |
+
ids["labels"] = ids["input_ids"].clone()
|
| 346 |
+
ids = ids.to(self.opt.device)
|
| 347 |
+
loss = self.MLM(**ids)["loss"]
|
| 348 |
+
perplexity = float(torch.exp(loss / ids["input_ids"].size(1)))
|
| 349 |
+
else:
|
| 350 |
+
perplexity = "N.A."
|
| 351 |
+
|
| 352 |
+
result = {
|
| 353 |
+
"text": text_raw,
|
| 354 |
+
"label": self.opt.index_to_label[pred_label],
|
| 355 |
+
"probs": prob.cpu().numpy(),
|
| 356 |
+
"confidence": float(max(prob)),
|
| 357 |
+
"ref_label": self.opt.index_to_label[ref_label]
|
| 358 |
+
if isinstance(ref_label, int)
|
| 359 |
+
else ref_label,
|
| 360 |
+
"ref_label_check": correct[pred_label == ref_label]
|
| 361 |
+
if ref_label != -100
|
| 362 |
+
else "",
|
| 363 |
+
"is_fixed": False,
|
| 364 |
+
"is_adv_label": self.opt.index_to_is_adv[pred_is_adv_label],
|
| 365 |
+
"is_adv_probs": advdet_prob.cpu().numpy(),
|
| 366 |
+
"is_adv_confidence": float(max(advdet_prob)),
|
| 367 |
+
"ref_is_adv_label": self.opt.index_to_is_adv[ref_is_adv_label]
|
| 368 |
+
if isinstance(ref_is_adv_label, int)
|
| 369 |
+
else ref_is_adv_label,
|
| 370 |
+
"ref_is_adv_check": correct[
|
| 371 |
+
pred_is_adv_label == ref_is_adv_label
|
| 372 |
+
]
|
| 373 |
+
if ref_is_adv_label != -100
|
| 374 |
+
and isinstance(ref_is_adv_label, int)
|
| 375 |
+
else "",
|
| 376 |
+
"pred_adv_tr_label": self.opt.index_to_label[pred_adv_tr_label],
|
| 377 |
+
"ref_adv_tr_label": self.opt.index_to_label[ref_adv_tr_label],
|
| 378 |
+
"perplexity": perplexity,
|
| 379 |
+
}
|
| 380 |
+
if defense:
|
| 381 |
+
try:
|
| 382 |
+
if not hasattr(self, "sent_attacker"):
|
| 383 |
+
self.sent_attacker = init_attacker(
|
| 384 |
+
self, defense.lower()
|
| 385 |
+
)
|
| 386 |
+
if result["is_adv_label"] == "1":
|
| 387 |
+
res = self.sent_attacker.attacker.simple_attack(
|
| 388 |
+
text_raw, int(result["label"])
|
| 389 |
+
)
|
| 390 |
+
new_infer_res = self.infer(
|
| 391 |
+
res.perturbed_result.attacked_text.text,
|
| 392 |
+
print_result=False,
|
| 393 |
+
)
|
| 394 |
+
result["perturbed_label"] = result["label"]
|
| 395 |
+
result["label"] = new_infer_res["label"]
|
| 396 |
+
result["probs"] = new_infer_res["probs"]
|
| 397 |
+
result["ref_label_check"] = (
|
| 398 |
+
correct[int(result["label"]) == ref_label]
|
| 399 |
+
if ref_label != -100
|
| 400 |
+
else ""
|
| 401 |
+
)
|
| 402 |
+
result[
|
| 403 |
+
"restored_text"
|
| 404 |
+
] = res.perturbed_result.attacked_text.text
|
| 405 |
+
result["is_fixed"] = True
|
| 406 |
+
else:
|
| 407 |
+
result["restored_text"] = ""
|
| 408 |
+
result["is_fixed"] = False
|
| 409 |
+
|
| 410 |
+
except Exception as e:
|
| 411 |
+
print(
|
| 412 |
+
"Error:{}, try install TextAttack and tensorflow_text after 10 seconds...".format(
|
| 413 |
+
e
|
| 414 |
+
)
|
| 415 |
+
)
|
| 416 |
+
time.sleep(10)
|
| 417 |
+
raise RuntimeError("Installation done, please run again...")
|
| 418 |
+
|
| 419 |
+
if ref_label != -100:
|
| 420 |
+
n_labeled += 1
|
| 421 |
+
|
| 422 |
+
if result["label"] == result["ref_label"]:
|
| 423 |
+
n_correct += 1
|
| 424 |
+
|
| 425 |
+
if ref_is_adv_label != -100:
|
| 426 |
+
n_advdet_labeled += 1
|
| 427 |
+
if ref_is_adv_label == pred_is_adv_label:
|
| 428 |
+
n_advdet_correct += 1
|
| 429 |
+
|
| 430 |
+
results.append(result)
|
| 431 |
+
|
| 432 |
+
try:
|
| 433 |
+
if print_result:
|
| 434 |
+
for ex_id, result in enumerate(results):
|
| 435 |
+
text_printing = result["text"][:]
|
| 436 |
+
text_info = ""
|
| 437 |
+
if result["label"] != "-100":
|
| 438 |
+
if not result["ref_label"]:
|
| 439 |
+
text_info += " -> <CLS:{}(ref:{} confidence:{})>".format(
|
| 440 |
+
result["label"],
|
| 441 |
+
result["ref_label"],
|
| 442 |
+
result["confidence"],
|
| 443 |
+
)
|
| 444 |
+
elif result["label"] == result["ref_label"]:
|
| 445 |
+
text_info += colored(
|
| 446 |
+
" -> <CLS:{}(ref:{} confidence:{})>".format(
|
| 447 |
+
result["label"],
|
| 448 |
+
result["ref_label"],
|
| 449 |
+
result["confidence"],
|
| 450 |
+
),
|
| 451 |
+
"green",
|
| 452 |
+
)
|
| 453 |
+
else:
|
| 454 |
+
text_info += colored(
|
| 455 |
+
" -> <CLS:{}(ref:{} confidence:{})>".format(
|
| 456 |
+
result["label"],
|
| 457 |
+
result["ref_label"],
|
| 458 |
+
result["confidence"],
|
| 459 |
+
),
|
| 460 |
+
"red",
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
# AdvDet
|
| 464 |
+
if result["is_adv_label"] != "-100":
|
| 465 |
+
if not result["ref_is_adv_label"]:
|
| 466 |
+
text_info += " -> <AdvDet:{}(ref:{} confidence:{})>".format(
|
| 467 |
+
result["is_adv_label"],
|
| 468 |
+
result["ref_is_adv_check"],
|
| 469 |
+
result["is_adv_confidence"],
|
| 470 |
+
)
|
| 471 |
+
elif result["is_adv_label"] == result["ref_is_adv_label"]:
|
| 472 |
+
text_info += colored(
|
| 473 |
+
" -> <AdvDet:{}(ref:{} confidence:{})>".format(
|
| 474 |
+
result["is_adv_label"],
|
| 475 |
+
result["ref_is_adv_label"],
|
| 476 |
+
result["is_adv_confidence"],
|
| 477 |
+
),
|
| 478 |
+
"green",
|
| 479 |
+
)
|
| 480 |
+
else:
|
| 481 |
+
text_info += colored(
|
| 482 |
+
" -> <AdvDet:{}(ref:{} confidence:{})>".format(
|
| 483 |
+
result["is_adv_label"],
|
| 484 |
+
result["ref_is_adv_label"],
|
| 485 |
+
result["is_adv_confidence"],
|
| 486 |
+
),
|
| 487 |
+
"red",
|
| 488 |
+
)
|
| 489 |
+
text_printing += text_info
|
| 490 |
+
if self.cal_perplexity:
|
| 491 |
+
text_printing += colored(
|
| 492 |
+
" --> <perplexity:{}>".format(result["perplexity"]),
|
| 493 |
+
"yellow",
|
| 494 |
+
)
|
| 495 |
+
print("Example {}: {}".format(ex_id, text_printing))
|
| 496 |
+
if save_path:
|
| 497 |
+
with open(save_path, "w", encoding="utf8") as fout:
|
| 498 |
+
json.dump(str(results), fout, ensure_ascii=False)
|
| 499 |
+
print("inference result saved in: {}".format(save_path))
|
| 500 |
+
except Exception as e:
|
| 501 |
+
print("Can not save result: {}, Exception: {}".format(text_raw, e))
|
| 502 |
+
|
| 503 |
+
if len(results) > 1:
|
| 504 |
+
print(
|
| 505 |
+
"CLS Acc:{}%".format(100 * n_correct / n_labeled if n_labeled else "")
|
| 506 |
+
)
|
| 507 |
+
print(
|
| 508 |
+
"AdvDet Acc:{}%".format(
|
| 509 |
+
100 * n_advdet_correct / n_advdet_labeled
|
| 510 |
+
if n_advdet_labeled
|
| 511 |
+
else ""
|
| 512 |
+
)
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
return results
|
| 516 |
+
|
| 517 |
+
def clear_input_samples(self):
    """Discard all queued samples by resetting ``self.dataset.all_data`` to an empty list."""
    self.dataset.all_data = list()
|
anonymous_demo/functional/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from anonymous_demo.functional.checkpoint.checkpoint_manager import TADCheckpointManager
|
| 2 |
+
|
| 3 |
+
from anonymous_demo.functional.config import TADConfigManager
|
anonymous_demo/functional/checkpoint/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .checkpoint_manager import TADCheckpointManager
|
anonymous_demo/functional/checkpoint/checkpoint_manager.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from findfile import find_file
|
| 3 |
+
|
| 4 |
+
from anonymous_demo.core.tad.prediction.tad_classifier import TADTextClassifier
|
| 5 |
+
from anonymous_demo.utils.demo_utils import retry
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class CheckpointManager:
    """Empty base class that concrete checkpoint managers derive from."""
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TADCheckpointManager(CheckpointManager):
    """Entry point for obtaining a ready-to-use TAD text classifier."""

    @staticmethod
    @retry
    def get_tad_text_classifier(checkpoint: str = None, eval_batch_size=128, **kwargs):
        """Construct a ``TADTextClassifier``.

        The call is wrapped by ``retry`` from ``demo_utils``
        (NOTE(review): presumably re-invokes on failure; confirm there).

        :param checkpoint: forwarded as the first positional argument of
            ``TADTextClassifier``
        :param eval_batch_size: forwarded as the ``eval_batch_size`` keyword
        :param kwargs: any further keyword arguments, forwarded unchanged
        :return: the constructed ``TADTextClassifier`` instance
        """
        return TADTextClassifier(
            checkpoint,
            eval_batch_size=eval_batch_size,
            **kwargs,
        )
|
anonymous_demo/functional/config/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .tad_config_manager import TADConfigManager
|
anonymous_demo/functional/config/config_manager.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from argparse import Namespace
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
one_shot_messages = set()


def config_check(args):
    """Validate a parameter dict. Currently a placeholder that accepts everything."""
    pass


class ConfigManager(Namespace):
    """``argparse.Namespace`` backed by a parameter dict with per-parameter read counters.

    Parameters live in ``self.args``; ``self.args_call_count`` records how many
    times each one has been read through attribute access.
    """

    def __init__(self, args=None, **kwargs):
        """
        Build the manager from a parameter dict (or an existing ``Namespace``).

        :param args: A parameter dict or a ``Namespace``; a falsy value is
            replaced by a new empty dict
        :param kwargs: Same params as ``Namespace``; they are stored as plain
            instance attributes because they are assigned before ``self.args``
            exists
        """
        if not args:
            args = {}
        super().__init__(**kwargs)

        if isinstance(args, Namespace):
            self.args = vars(args)
            self.args_call_count = {arg: 0 for arg in vars(args)}
        else:
            self.args = args
            self.args_call_count = {arg: 0 for arg in args}

    def __getattribute__(self, arg_name):
        """Return ``args[arg_name]`` (counting the read) or fall back to normal lookup."""
        if arg_name in ("args", "args_call_count"):
            return super().__getattribute__(arg_name)
        try:
            value = super().__getattribute__("args")[arg_name]
            counters = super().__getattribute__("args_call_count")
        except (AttributeError, KeyError, TypeError):
            # Not a managed parameter (or the dicts are not set up yet):
            # behave like an ordinary attribute.
            return super().__getattribute__(arg_name)
        # ``.get`` keeps the read working even if a key was added to ``args``
        # directly, without a matching counter.
        counters[arg_name] = counters.get(arg_name, 0) + 1
        return value

    def __setattr__(self, arg_name, value):
        """Store ``value`` in ``args`` and make sure a read counter exists for it."""
        if arg_name in ("args", "args_call_count"):
            super().__setattr__(arg_name, value)
            return
        try:
            args = super().__getattribute__("args")
            counters = super().__getattribute__("args_call_count")
        except AttributeError:
            # ``Namespace.__init__`` assigns **kwargs before ``args`` exists.
            # Keep those as plain attributes; there is no parameter dict to
            # validate yet, so ``config_check`` must not be reached here.
            super().__setattr__(arg_name, value)
            return

        try:
            # Both dicts are mutated in place, so they need no re-assignment.
            args[arg_name] = value
            counters.setdefault(arg_name, 0)
        except (TypeError, KeyError):
            # ``args`` does not support item assignment: best-effort fallback.
            super().__setattr__(arg_name, value)

        config_check(args)
|
anonymous_demo/functional/config/tad_config_manager.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
|
| 3 |
+
from anonymous_demo.functional.config.config_manager import ConfigManager
|
| 4 |
+
from anonymous_demo.core.tad.classic.__bert__.models import TADBERT
|
| 5 |
+
|
| 6 |
+
# Hyper-parameter presets. Each preset is a complete, independent dict; key
# order is kept exactly as originally declared.
# NOTE(review): the original comment on ``cross_validate_fold`` reads "split
# train and test datasets into 5 folds and repeat 3 training"; the value here
# is -1, presumably meaning cross-validation is disabled. Confirm.
_tad_config_template = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    patience=99999,
    pretrained_bert="microsoft/mdeberta-v3-base",
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    max_seq_len=80,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    cross_validate_fold=-1,
    use_amp=False,
)

_tad_config_base = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    pretrained_bert="microsoft/deberta-v3-base",
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    max_seq_len=80,
    patience=99999,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    cross_validate_fold=-1,
)

_tad_config_english = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    patience=99999,
    pretrained_bert="microsoft/deberta-v3-base",
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    max_seq_len=80,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    cross_validate_fold=-1,
)

_tad_config_multilingual = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    patience=99999,
    pretrained_bert="microsoft/mdeberta-v3-base",
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    max_seq_len=80,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    cross_validate_fold=-1,
)

_tad_config_chinese = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    patience=99999,
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    pretrained_bert="bert-base-chinese",
    max_seq_len=80,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    cross_validate_fold=-1,
)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class TADConfigManager(ConfigManager):
    """``ConfigManager`` for the TAD task plus accessors for the module-level presets."""

    def __init__(self, args, **kwargs):
        """
        Wrap a parameter dict in a config manager.

        The module-level presets (``_tad_config_template`` and friends) show
        the keys in use, e.g. ``model``, ``optimizer``, ``learning_rate``,
        ``pretrained_bert``, ``cache_dataset``, ``warmup_step``,
        ``show_metric``, ``max_seq_len``, ``patience``, ``dropout``, ``l2reg``,
        ``num_epoch``, ``batch_size``, ``initializer``, ``seed``,
        ``polarities_dim``, ``log_step``, ``evaluate_begin`` and
        ``cross_validate_fold``.

        :param args: parameter dict (or ``Namespace``) handed to ``ConfigManager``
        :param kwargs: extra keyword arguments handed to ``ConfigManager``
        """
        super().__init__(args, **kwargs)

    @staticmethod
    def _tad_config_preset(configType):
        """Return the module-level preset dict registered under ``configType``.

        :raises ValueError: for an unknown name, or for ``"glove"`` when this
            module does not define ``_tad_config_glove``
        """
        presets = {
            "template": _tad_config_template,
            "base": _tad_config_base,
            "english": _tad_config_english,
            "chinese": _tad_config_chinese,
            "multilingual": _tad_config_multilingual,
        }
        if configType == "glove":
            # ``_tad_config_glove`` is not defined in this module; look it up
            # dynamically so a missing preset gives a clear error instead of a
            # bare NameError.
            glove = globals().get("_tad_config_glove")
            if glove is None:
                raise ValueError(
                    "The 'glove' config preset is not defined in this module"
                )
            return glove
        if configType not in presets:
            raise ValueError(
                "Wrong value of config type supplied, please use one from following type: template, base, english, chinese, multilingual, glove"
            )
        return presets[configType]

    @staticmethod
    def _build_config(configType):
        """Return a fresh manager holding template values overridden by the preset.

        The merge happens on a copy, so asking for one preset never changes
        what a later ``get_tad_config_template()`` returns.
        """
        merged = dict(_tad_config_template)
        merged.update(TADConfigManager._tad_config_preset(configType))
        return TADConfigManager(copy.deepcopy(merged))

    @staticmethod
    def set_tad_config(configType: str, newitem: dict):
        """Update the preset named ``configType`` in place with ``newitem``.

        :param configType: one of template, base, english, chinese,
            multilingual, glove
        :param newitem: dict of entries to add or override
        :raises TypeError: if ``newitem`` is not a dict
        :raises ValueError: if ``configType`` names no available preset
        """
        if not isinstance(newitem, dict):
            raise TypeError(
                "Wrong type of new config item supplied, please use dict e.g.{'NewConfig': NewValue}"
            )
        TADConfigManager._tad_config_preset(configType).update(newitem)

    @staticmethod
    def set_tad_config_template(newitem):
        """Update the ``template`` preset with ``newitem``."""
        TADConfigManager.set_tad_config("template", newitem)

    @staticmethod
    def set_tad_config_base(newitem):
        """Update the ``base`` preset with ``newitem``."""
        TADConfigManager.set_tad_config("base", newitem)

    @staticmethod
    def set_tad_config_english(newitem):
        """Update the ``english`` preset with ``newitem``."""
        TADConfigManager.set_tad_config("english", newitem)

    @staticmethod
    def set_tad_config_chinese(newitem):
        """Update the ``chinese`` preset with ``newitem``."""
        TADConfigManager.set_tad_config("chinese", newitem)

    @staticmethod
    def set_tad_config_multilingual(newitem):
        """Update the ``multilingual`` preset with ``newitem``."""
        TADConfigManager.set_tad_config("multilingual", newitem)

    @staticmethod
    def set_tad_config_glove(newitem):
        """Update the ``glove`` preset with ``newitem`` (if that preset exists)."""
        TADConfigManager.set_tad_config("glove", newitem)

    @staticmethod
    def get_tad_config_template() -> ConfigManager:
        """Return a manager holding a deep copy of the template preset."""
        return TADConfigManager._build_config("template")

    @staticmethod
    def get_tad_config_base() -> ConfigManager:
        """Return a manager holding the template overridden by the ``base`` preset."""
        return TADConfigManager._build_config("base")

    @staticmethod
    def get_tad_config_english() -> ConfigManager:
        """Return a manager holding the template overridden by the ``english`` preset."""
        return TADConfigManager._build_config("english")

    @staticmethod
    def get_tad_config_chinese() -> ConfigManager:
        """Return a manager holding the template overridden by the ``chinese`` preset."""
        return TADConfigManager._build_config("chinese")

    @staticmethod
    def get_tad_config_multilingual() -> ConfigManager:
        """Return a manager holding the template overridden by the ``multilingual`` preset."""
        return TADConfigManager._build_config("multilingual")

    @staticmethod
    def get_tad_config_glove() -> ConfigManager:
        """Return a manager holding the template overridden by the ``glove`` preset.

        :raises ValueError: if this module defines no ``_tad_config_glove``
        """
        return TADConfigManager._build_config("glove")
|
anonymous_demo/functional/dataset/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from anonymous_demo.functional.dataset.dataset_manager import detect_infer_dataset
|
anonymous_demo/functional/dataset/dataset_manager.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from findfile import find_files, find_dir
|
| 3 |
+
|
| 4 |
+
# Substrings that disqualify a path from being treated as a dataset file.
filter_key_words = [
    ".py",
    ".md",
    "readme",
    "log",
    "result",
    "zip",
    ".state_dict",
    ".model",
    ".png",
    "acc_",
    "f1_",
    ".backup",
    ".bak",
]


def detect_infer_dataset(dataset_path, task="apc"):
    """Collect the inference dataset files designated by ``dataset_path``.

    :param dataset_path: a path to a single file (returned as-is), a single
        directory or dataset name, or an iterable of directories / dataset names
    :param task: task keyword used when searching for dataset files
    :return: list of matching file paths
    """
    if isinstance(dataset_path, str):
        if os.path.isfile(dataset_path):
            return [dataset_path]
        # A bare string is ONE directory / dataset name. Without this wrap the
        # loop below would iterate over its individual characters.
        dataset_path = [dataset_path]

    dataset_file = []
    excluded = ["train."] + filter_key_words
    for d in dataset_path:
        if os.path.exists(d):
            dataset_file += find_files(d, [".inference", task], exclude_key=excluded)
            continue
        # Not an existing path: treat it as a dataset name and search the
        # working directory for a matching dataset folder.
        search_path = find_dir(
            os.getcwd(),
            [d, task, "dataset"],
            exclude_key=filter_key_words,
            disable_alert=False,
        )
        dataset_file += find_files(
            search_path,
            [".inference", d],
            exclude_key=excluded,
        )

    return dataset_file
|
anonymous_demo/network/__init__.py
ADDED
|
File without changes
|
anonymous_demo/network/lcf_pooler.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class LCF_Pooler(nn.Module):
    """Pool one hidden vector per sample, chosen via the local-context mask.

    For each sample, the positions whose ``lcf_vec`` row sums to its own length
    (i.e. ``sum(row - 1.0) == 0``) are collected; the hidden state at the middle
    one of those positions goes through a linear layer and ``tanh``.
    """

    def __init__(self, config):
        """:param config: object exposing ``hidden_size``"""
        super().__init__()
        self.config = config
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, lcf_vec):
        """Return the pooled features, one row per sample.

        Indexing ``shape[2]`` shows ``hidden_states`` is at least 3-D; the
        selection is done in NumPy on detached CPU copies, so no gradient flows
        back into ``hidden_states`` through this module.

        :raises IndexError: if a sample has no qualifying position
        """
        target_device = hidden_states.device
        mask_np = lcf_vec.detach().cpu().numpy()

        n_samples, width = hidden_states.shape[0], hidden_states.shape[2]
        picked = numpy.zeros((n_samples, width), dtype=numpy.float32)
        states_np = hidden_states.detach().cpu().numpy()

        for row, mask in enumerate(mask_np):
            matching = []
            for pos in range(len(mask)):
                if sum(mask[pos] - 1.0) == 0:
                    matching.append(pos)
            middle = matching[len(matching) // 2]
            picked[row] = states_np[row][middle]

        features = torch.Tensor(picked).to(target_device)
        return self.activation(self.dense(features))
|
anonymous_demo/network/lsa.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from anonymous_demo.network.sa_encoder import Encoder
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class LSA(nn.Module):
    """Combine centre, left and right local-context features into one representation.

    Three separate ``Encoder`` instances process the masked context features;
    ``opt.window`` selects which of the side windows are concatenated with the
    centre features before the final linear projection.
    """

    def __init__(self, bert, opt):
        """
        :param bert: model object whose ``config`` is passed to each ``Encoder``
        :param opt: options exposing ``embed_dim``, ``eta`` and ``window``
        """
        super(LSA, self).__init__()
        self.opt = opt

        self.encoder = Encoder(bert.config, opt)
        self.encoder_left = Encoder(bert.config, opt)
        self.encoder_right = Encoder(bert.config, opt)
        self.linear_window_3h = nn.Linear(opt.embed_dim * 3, opt.embed_dim)
        self.linear_window_2h = nn.Linear(opt.embed_dim * 2, opt.embed_dim)
        # Learnable weights of the left / right windows, both starting at opt.eta.
        self.eta1 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))
        self.eta2 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))

    def forward(
        self,
        global_context_features,
        spc_mask_vec,
        lcf_matrix,
        left_lcf_matrix,
        right_lcf_matrix,
    ):
        """Return the projected window features.

        :raises KeyError: if ``opt.window`` is not one of lr, rl, l, r
        """
        masked_context = torch.mul(spc_mask_vec, global_context_features)

        # All three encoders always run, whichever window is selected.
        centre = self.encoder(torch.mul(global_context_features, lcf_matrix))
        left = self.encoder_left(torch.mul(masked_context, left_lcf_matrix))
        right = self.encoder_right(torch.mul(masked_context, right_lcf_matrix))

        if self.opt.window == "lr" or self.opt.window == "rl":
            # Re-draw a non-positive weight unless eta is the sentinel -1.
            if self.eta1 <= 0 and self.opt.eta != -1:
                torch.nn.init.uniform_(self.eta1)
                print("reset eta1 to: {}".format(self.eta1.item()))
            if self.eta2 <= 0 and self.opt.eta != -1:
                torch.nn.init.uniform_(self.eta2)
                print("reset eta2 to: {}".format(self.eta2.item()))

            if self.opt.eta >= 0:
                parts = (centre, self.eta1 * left, self.eta2 * right)
            else:
                parts = (centre, left, right)
            return self.linear_window_3h(torch.cat(parts, -1))

        if self.opt.window == "l":
            joined = torch.cat((centre, self.eta1 * left), -1)
            return self.linear_window_2h(joined)

        if self.opt.window == "r":
            joined = torch.cat((centre, self.eta2 * right), -1)
            return self.linear_window_2h(joined)

        raise KeyError("Invalid parameter:", self.opt.window)
|
anonymous_demo/network/sa_encoder.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention with optional relative position scores."""

    def __init__(self, config):
        """
        :param config: object exposing ``hidden_size``, ``num_attention_heads``
            and ``is_decoder``; ``attention_probs_dropout_prob``,
            ``position_embedding_type`` and ``max_position_embeddings`` are
            read when present / required
        :raises ValueError: if ``hidden_size`` is not divisible by the number of
            heads and ``config`` has no ``embedding_size`` attribute
        """
        super().__init__()
        indivisible = config.hidden_size % config.num_attention_heads != 0
        if indivisible and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # No dropout when the config does not define a probability.
        self.dropout = nn.Dropout(getattr(config, "attention_probs_dropout_prob", 0))
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size
            )

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move the head axis to position 1."""
        head_shape = (self.num_attention_heads, self.attention_head_size)
        split = x.view(*(x.size()[:-1] + head_shape))
        return split.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention and return a tuple.

        The tuple is ``(context,)``, followed by the attention probabilities
        when ``output_attentions`` is true, followed by the ``(key, value)``
        pair when the module is configured as a decoder.
        """
        mixed_query_layer = self.query(hidden_states)

        # With encoder states present this acts as cross-attention: keys and
        # values come from the encoder and the encoder mask is applied.
        is_cross_attention = encoder_hidden_states is not None
        has_cache = past_key_value is not None

        if is_cross_attention:
            if has_cache:
                # Reuse the cached cross-attention keys / values.
                key_layer, value_layer = past_key_value[0], past_key_value[1]
            else:
                key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
                value_layer = self.transpose_for_scores(
                    self.value(encoder_hidden_states)
                )
            attention_mask = encoder_attention_mask
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            if has_cache:
                # Prepend cached keys / values along the sequence axis.
                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # Hand the current keys / values back so later calls can reuse them.
            past_key_value = (key_layer, value_layer)

        # Raw attention scores: dot product between queries and keys.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            seq_length = hidden_states.size()[1]
            rows = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            cols = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1)
            distance = rows - cols
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            )
            # fp16 compatibility
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)

            if self.position_embedding_type == "relative_key":
                attention_scores = attention_scores + torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
            elif self.position_embedding_type == "relative_key_query":
                from_query = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                from_key = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding
                )
                attention_scores = attention_scores + from_query + from_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive.
            attention_scores = attention_scores + attention_mask

        # Normalise scores to probabilities, then apply dropout to them.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*merged_shape)

        if output_attentions:
            outputs = (context_layer, attention_probs)
        else:
            outputs = (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class Encoder(nn.Module):
    """Stack of ``SelfAttention`` layers, each followed by ``tanh``."""

    def __init__(self, config, opt, layer_num=1):
        """
        :param config: configuration passed to every ``SelfAttention`` layer
        :param opt: options passed to every ``SelfAttention`` layer
        :param layer_num: number of stacked layers
        """
        super(Encoder, self).__init__()
        self.opt = opt
        self.config = config
        layers = []
        for _ in range(layer_num):
            layers.append(SelfAttention(config, opt))
        self.encoder = nn.ModuleList(layers)
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        """Feed ``x`` through each layer, keeping the first output element of each."""
        for layer in self.encoder:
            x = self.tanh(layer(x)[0])
        return x
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class SelfAttention(nn.Module):
    """Wrap ``BertSelfAttention`` and call it with an all-zero additive mask."""

    def __init__(self, config, opt):
        """
        :param config: configuration handed to ``BertSelfAttention``
        :param opt: options exposing ``max_seq_len``
        """
        super(SelfAttention, self).__init__()
        self.opt = opt
        self.config = config
        self.SA = BertSelfAttention(config)

    def forward(self, inputs):
        """Return the output of the wrapped attention module for ``inputs``.

        The mask is float32 zeros of shape ``(inputs.size(0), 1, 1,
        opt.max_seq_len)`` on the device of ``inputs``.
        """
        mask_shape = (inputs.size(0), 1, 1, self.opt.max_seq_len)
        no_mask = torch.zeros(mask_shape, dtype=torch.float32, device=inputs.device)
        return self.SA(inputs, no_mask)
|
anonymous_demo/utils/__init__.py
ADDED
|
File without changes
|
anonymous_demo/utils/demo_utils.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import pickle
|
| 4 |
+
import signal
|
| 5 |
+
import threading
|
| 6 |
+
import time
|
| 7 |
+
import zipfile
|
| 8 |
+
|
| 9 |
+
import gdown
|
| 10 |
+
import numpy as np
|
| 11 |
+
import requests
|
| 12 |
+
import torch
|
| 13 |
+
import tqdm
|
| 14 |
+
from autocuda import auto_cuda, auto_cuda_name
|
| 15 |
+
from findfile import find_files, find_cwd_file, find_file
|
| 16 |
+
from termcolor import colored
|
| 17 |
+
from functools import wraps
|
| 18 |
+
|
| 19 |
+
from update_checker import parse_version
|
| 20 |
+
|
| 21 |
+
from anonymous_demo import __version__
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def save_args(config, save_path):
    """Write the used arguments of ``config`` to ``save_path`` as ``name: value`` lines.

    Args:
        config: Object exposing ``args`` (mapping of argument name to value)
            and ``args_call_count`` (mapping of argument name to a counter).
        save_path: Destination file path; the file is overwritten.

    Only arguments whose ``args_call_count`` entry is truthy are written.
    """
    # ``with`` guarantees the handle is closed even if a write/format fails;
    # the previous open()/close() pair leaked the file on any exception.
    with open(save_path, mode="w", encoding="utf8") as f:
        for arg in config.args:
            if config.args_call_count[arg]:
                f.write("{}: {}\n".format(arg, config.args[arg]))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def print_args(config, logger=None, mode=0):
    """Report every argument of ``config`` together with its call counter.

    Args:
        config: Object exposing ``args`` and ``args_call_count`` mappings.
        logger: If truthy, lines go to ``logger.info``; otherwise to stdout.
        mode: Accepted for interface compatibility; not used by this function.

    Arguments are reported in sorted name order.
    """
    for arg in sorted(config.args.keys()):
        line = "{0}:{1}\t-->\tCalling Count:{2}".format(
            arg, config.args[arg], config.args_call_count[arg]
        )
        if logger:
            logger.info(line)
        else:
            print(line)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def check_and_fix_labels(label_set: set, label_name, all_data, opt):
    """Map the original labels to integer indices and rewrite ``all_data`` in place.

    Args:
        label_set: All distinct original labels.
        label_name: Key under which each item of ``all_data`` stores its label.
        all_data: Items to relabel; each is indexed with ``item[label_name]``,
            falling back to the ``item.polarity`` attribute if that fails.
        opt: Config object; receives/updates ``index_to_label`` and
            ``label_to_index``.

    Labels are numbered in sorted order. When the string ``"-100"`` is among
    the labels it keeps the index ``-100`` and every other label gets its
    sorted position minus one. A per-label count (plus ``"Sum"``) is printed.
    """
    has_ignore_label = "-100" in label_set
    label_to_index = {}
    index_to_label = {}
    for idx, origin_label in enumerate(sorted(label_set)):
        if has_ignore_label:
            mapped = int(idx) - 1 if origin_label != "-100" else -100
        else:
            mapped = int(idx)
        label_to_index[origin_label] = mapped
        index_to_label[mapped] = origin_label

    # First call for this config: install the freshly built mappings.
    if "index_to_label" not in opt.args:
        opt.index_to_label = index_to_label
        opt.label_to_index = label_to_index

    # Mappings already present but different: merge the new entries in.
    if opt.index_to_label != index_to_label:
        opt.index_to_label.update(index_to_label)
        opt.label_to_index.update(label_to_index)

    num_label = dict.fromkeys(label_set, 0)
    num_label["Sum"] = len(all_data)
    for item in all_data:
        try:
            num_label[item[label_name]] += 1
            item[label_name] = label_to_index[item[label_name]]
        except Exception:
            # Item is not usable via ``item[label_name]``; use its ``polarity``
            # attribute instead.
            num_label[item.polarity] += 1
            item.polarity = label_to_index[item.polarity]
    print("Dataset Label Details: {}".format(num_label))
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def check_and_fix_IOB_labels(label_map, opt):
    """Store the inverse of ``label_map`` on ``opt.index_to_IOB_label``.

    Args:
        label_map: Mapping from original label to an index (converted with
            ``int``).
        opt: Config object that receives the ``index_to_IOB_label`` attribute.
    """
    inverted = {}
    for origin_label in label_map:
        inverted[int(label_map[origin_label])] = origin_label
    opt.index_to_IOB_label = inverted
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def get_device(auto_device):
    """Resolve ``auto_device`` to a device spec and a device name.

    Args:
        auto_device: ``"allcuda"`` (mapped to ``"cuda"``), any other string
            (used as-is), a bool (``True`` -> ``auto_cuda()``, ``False`` ->
            ``"cpu"``), or anything else (-> ``auto_cuda()``).

    Returns:
        Tuple ``(device, device_name)`` where ``device_name`` always comes
        from ``auto_cuda_name()``.
    """
    if isinstance(auto_device, str):
        device = "cuda" if auto_device == "allcuda" else auto_device
    elif isinstance(auto_device, bool):
        device = "cpu" if not auto_device else auto_cuda()
    else:
        device = auto_cuda()

    # Validate the spec; anything torch rejects is redirected to the CPU.
    try:
        torch.device(device)
    except RuntimeError as e:
        message = "Device assignment error: {}, redirect to CPU".format(e)
        print(colored(message, "red"))
        device = "cpu"

    return device, auto_cuda_name()
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _load_word_vec(path, word2idx=None, embed_dim=300):
    """Read word vectors from a whitespace-separated text embedding file.

    Each line is split on whitespace; the last ``embed_dim`` tokens form the
    vector and everything before them (re-joined with single spaces) is the
    word.

    Args:
        path: Path of the embedding text file.
        word2idx: Optional vocabulary; when given, only words contained in it
            are kept. When ``None`` every word in the file is kept (the
            previous code crashed on ``None`` despite it being the default).
        embed_dim: Number of trailing tokens per line that form the vector.

    Returns:
        Dict mapping word to a ``float32`` NumPy vector.
    """
    word_vec = {}
    # Stream the file line by line inside ``with``: the handle is always
    # closed and the whole file is no longer loaded into memory at once.
    with open(path, "r", encoding="utf-8", newline="\n", errors="ignore") as fin:
        for line in tqdm.tqdm(fin, postfix="Loading embedding file..."):
            tokens = line.rstrip().split()
            word, vec = " ".join(tokens[:-embed_dim]), tokens[-embed_dim:]
            if word2idx is None or word in word2idx:
                word_vec[word] = np.asarray(vec, dtype="float32")
    return word_vec
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def build_embedding_matrix(word2idx, embed_dim, dat_fname, opt):
    """Build the embedding matrix for ``word2idx`` or load it from the cache.

    The cache file lives at ``run/<opt.dataset_name>/<dat_fname>``. If it
    exists it is unpickled and returned; otherwise the matrix is built from
    the vectors returned by ``_load_word_vec`` and pickled to that path.

    Args:
        word2idx: Mapping from word to row index in the matrix.
        embed_dim: Embedding dimensionality (number of matrix columns).
        dat_fname: File name of the cache file.
        opt: Config object providing ``dataset_name``.

    Returns:
        NumPy array of shape ``(len(word2idx) + 2, embed_dim)`` when freshly
        built; rows of words without a vector stay zero.
    """
    embed_matrix_path = "run/{}".format(os.path.join(opt.dataset_name, dat_fname))
    # Create the full parent directory (not just "run"), otherwise writing the
    # cache fails when "run/<dataset_name>" does not exist yet.
    os.makedirs(os.path.dirname(embed_matrix_path), exist_ok=True)

    if os.path.exists(embed_matrix_path):
        print(
            colored(
                "Loading cached embedding_matrix from {} (Please remove all cached files if there is any problem!)".format(
                    embed_matrix_path
                ),
                "green",
            )
        )
        # NOTE(review): pickle.load executes arbitrary code from the file; this
        # is only safe as long as the cache directory is trusted.
        with open(embed_matrix_path, "rb") as cache_file:
            embedding_matrix = pickle.load(cache_file)
    else:
        # NOTE(review): ``prepare_glove840_embedding`` is neither defined nor
        # imported in the visible part of this module -- confirm it is
        # available at runtime, otherwise this branch raises NameError.
        glove_path = prepare_glove840_embedding(embed_matrix_path)
        embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))

        word_vec = _load_word_vec(glove_path, word2idx=word2idx, embed_dim=embed_dim)

        for word, i in tqdm.tqdm(
            word2idx.items(),
            postfix=colored("Building embedding_matrix {}".format(dat_fname), "yellow"),
        ):
            vec = word_vec.get(word)
            if vec is not None:
                embedding_matrix[i] = vec
        with open(embed_matrix_path, "wb") as cache_file:
            pickle.dump(embedding_matrix, cache_file)
    return embedding_matrix
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def pad_and_truncate(
    sequence, maxlen, dtype="int64", padding="post", truncating="post", value=0
):
    """Return ``sequence`` as a 1-D array of exactly ``maxlen`` elements.

    Args:
        sequence: Sliceable sequence of values.
        maxlen: Length of the returned array.
        dtype: NumPy dtype of the result.
        padding: ``"post"`` places the data at the start and pads at the end;
            any other value places the data at the end and pads at the start.
        truncating: ``"pre"`` keeps the last ``maxlen`` elements of an
            over-long sequence; any other value keeps the first ``maxlen``.
        value: Fill value for the padded positions.

    Returns:
        NumPy array of length ``maxlen`` and dtype ``dtype``.
    """
    x = (np.ones(maxlen) * value).astype(dtype)
    if truncating == "pre":
        # ``sequence[-maxlen:]`` would keep the WHOLE sequence for maxlen == 0
        # (because -0 == 0), so compute the start index explicitly.
        trunc = sequence[max(len(sequence) - maxlen, 0):]
    else:
        trunc = sequence[:maxlen]
    trunc = np.asarray(trunc, dtype=dtype)
    if len(trunc) == 0:
        # Nothing to copy. Without this guard the "pre" padding branch would
        # evaluate ``x[-0:]`` (the whole array) and fail to broadcast.
        return x
    if padding == "post":
        x[: len(trunc)] = trunc
    else:
        x[-len(trunc):] = trunc
    return x
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class TransformerConnectionError(ValueError):
    """Connection-related error; one of the exception types ``retry`` catches.

    The former no-argument ``__init__`` made it impossible to attach a
    message. The inherited constructor is used instead, so both
    ``TransformerConnectionError()`` and ``TransformerConnectionError("msg")``
    work.
    """
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def retry(f):
    """Decorator that retries ``f`` on connection-type errors.

    ``f`` is attempted up to 5 times. After a failed attempt the exception is
    printed and, if attempts remain, the wrapper sleeps 60 seconds before
    trying again.

    Returns:
        The wrapped function. It returns ``f``'s result on success and
        ``None`` when every attempt failed (the exception is not re-raised).
    """

    # ``RequestException`` is the base class of the requests errors that were
    # previously listed one by one (ConnectionError, HTTPError, ConnectTimeout,
    # ProxyError, SSLError), so this shorter tuple catches the same set.
    retryable = (
        TransformerConnectionError,
        requests.exceptions.RequestException,
        requests.exceptions.BaseHTTPError,
    )

    @wraps(f)
    def decorated(*args, **kwargs):
        attempts_left = 5
        while attempts_left:
            try:
                return f(*args, **kwargs)
            except retryable as e:
                print(colored("Training Exception: {}, will retry later".format(e)))
                attempts_left -= 1
                # Only wait when another attempt will follow; sleeping after
                # the final failure just delayed the caller for nothing.
                if attempts_left:
                    time.sleep(60)
        # All attempts failed: keep the historical best-effort contract.
        return None

    return decorated
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def save_json(dic, save_path):
    """Serialize ``dic`` as JSON (non-ASCII kept verbatim) into ``save_path``.

    Args:
        dic: Object to serialize. A ``str`` is first evaluated as a Python
            expression and the result is serialized.
        save_path: Destination file path; the file is overwritten.
    """
    # NOTE(review): eval() runs arbitrary code -- never pass untrusted strings.
    # Consider ast.literal_eval if only literals are expected.
    dic = eval(dic) if isinstance(dic, str) else dic
    with open(save_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(dic, ensure_ascii=False))
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def load_json(save_path):
    """Load and return the JSON document stored in ``save_path``.

    The whole file is parsed, so pretty-printed (multi-line) JSON works. For
    backward compatibility with the previous first-line-only behaviour, a
    file whose first line is a complete JSON document followed by other
    content is still accepted.

    Raises:
        json.JSONDecodeError: If neither the whole file nor its first line is
            valid JSON.
    """
    with open(save_path, "r", encoding="utf-8") as f:
        text = f.read()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Legacy fallback: parse only the first line, as the old code did.
        first_line = text.split("\n", 1)[0].strip()
        return json.loads(first_line)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def init_optimizer(optimizer):
    """Resolve ``optimizer`` to an optimizer class.

    Args:
        optimizer: Either one of the supported names (case-insensitive:
            adadelta, adagrad, adam, adamax, asgd, rmsprop, sgd, adamw) or an
            object whose ``__name__`` is an attribute of ``torch.optim``
            (e.g. the optimizer classes themselves).

    Returns:
        The optimizer class (for a name) or ``optimizer`` itself (for a class).

    Raises:
        KeyError: If ``optimizer`` cannot be resolved. Previously an unknown
            string raised ``AttributeError`` because ``str`` has no
            ``__name__``.
    """
    named_optimizers = {
        "adadelta": torch.optim.Adadelta,  # default lr=1.0
        "adagrad": torch.optim.Adagrad,  # default lr=0.01
        "adam": torch.optim.Adam,  # default lr=0.001
        "adamax": torch.optim.Adamax,  # default lr=0.002
        "asgd": torch.optim.ASGD,  # default lr=0.01
        "rmsprop": torch.optim.RMSprop,  # default lr=0.01
        "sgd": torch.optim.SGD,
        "adamw": torch.optim.AdamW,
    }
    if isinstance(optimizer, str):
        resolved = named_optimizers.get(optimizer.lower())
        if resolved is not None:
            return resolved
    else:
        # Covers the torch.optim classes that used to be listed as dict keys:
        # each of them has a ``__name__`` that exists in ``torch.optim``.
        name = getattr(optimizer, "__name__", None)
        if name is not None and hasattr(torch.optim, name):
            return optimizer
    raise KeyError(
        "Unsupported optimizer: {}. Please use string or the optimizer objects in torch.optim as your optimizer".format(
            optimizer
        )
    )
|
anonymous_demo/utils/logger.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
import termcolor
|
| 7 |
+
|
| 8 |
+
today = time.strftime("%Y%m%d %H%M%S", time.localtime(time.time()))


def get_logger(log_path, log_name="", log_type="training_log"):
    """Return a logger that writes to stdout and to a timestamped log file.

    The log file is ``<log_path>/logs/<log_name>_<today>/<log_type>.log``,
    where ``today`` is the module import timestamp. A falsy ``log_path``
    means the current directory.

    Args:
        log_path: Base directory for the ``logs`` folder, or a falsy value.
        log_name: Name of the logger and prefix of the run directory.
        log_type: Base name of the log file.

    Returns:
        The ``logging.Logger`` registered under ``log_name``. Handlers are
        only attached the first time a given name is requested.
    """
    # Bug fix: the condition used to be inverted -- a supplied ``log_path``
    # was ignored in favour of "." and ``None`` crashed in ``os.path.join``.
    base_dir = log_path if log_path else "."
    log_dir = os.path.join(base_dir, "logs")

    full_path = os.path.join(log_dir, log_name + "_" + today)
    os.makedirs(full_path, exist_ok=True)
    log_file = os.path.join(full_path, "{}.log".format(log_type))

    logger = logging.getLogger(log_name)
    if not logger.handlers:
        formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s")

        file_handler = logging.FileHandler(log_file, encoding="utf8")
        file_handler.setFormatter(formatter)
        file_handler.setLevel(logging.INFO)

        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(formatter)
        console_handler.setLevel(logging.INFO)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

    logger.setLevel(logging.INFO)

    return logger
|
checkpoints.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f77ae4a45785183900ee874cb318a16b0e2f173b31749a2555215aca93672f26
|
| 3 |
+
size 2456834455
|
text_defense/201.SST2/stsa.binary.dev.dat
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
text_defense/201.SST2/stsa.binary.test.dat
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
text_defense/201.SST2/stsa.binary.train.dat
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
text_defense/202.IMDB10K/imdb10k.test.dat
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
text_defense/202.IMDB10K/imdb10k.train.dat
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
text_defense/202.IMDB10K/imdb10k.valid.dat
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
text_defense/204.AGNews10K/AGNews10K.test.dat
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
text_defense/204.AGNews10K/AGNews10K.train.dat
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
text_defense/204.AGNews10K/AGNews10K.valid.dat
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
text_defense/206.Amazon_Review_Polarity10K/amazon.test.dat
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
text_defense/206.Amazon_Review_Polarity10K/amazon.train.dat
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
textattack/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Welcome to the API references for TextAttack!
|
| 2 |
+
|
| 3 |
+
What is TextAttack?
|
| 4 |
+
|
| 5 |
+
`TextAttack <https://github.com/QData/TextAttack>`__ is a Python framework for adversarial attacks, adversarial training, and data augmentation in NLP.
|
| 6 |
+
|
| 7 |
+
TextAttack makes experimenting with the robustness of NLP models seamless, fast, and easy. It's also useful for NLP model training, adversarial training, and data augmentation.
|
| 8 |
+
|
| 9 |
+
TextAttack provides components for common NLP tasks like sentence encoding, grammar-checking, and word replacement that can be used on their own.
|
| 10 |
+
"""
|
| 11 |
+
from .attack_args import AttackArgs, CommandLineAttackArgs
|
| 12 |
+
from .augment_args import AugmenterArgs
|
| 13 |
+
from .dataset_args import DatasetArgs
|
| 14 |
+
from .model_args import ModelArgs
|
| 15 |
+
from .training_args import TrainingArgs, CommandLineTrainingArgs
|
| 16 |
+
from .attack import Attack
|
| 17 |
+
from .attacker import Attacker
|
| 18 |
+
from .trainer import Trainer
|
| 19 |
+
from .metrics import Metric
|
| 20 |
+
|
| 21 |
+
from . import (
|
| 22 |
+
attack_recipes,
|
| 23 |
+
attack_results,
|
| 24 |
+
augmentation,
|
| 25 |
+
commands,
|
| 26 |
+
constraints,
|
| 27 |
+
datasets,
|
| 28 |
+
goal_function_results,
|
| 29 |
+
goal_functions,
|
| 30 |
+
loggers,
|
| 31 |
+
metrics,
|
| 32 |
+
models,
|
| 33 |
+
search_methods,
|
| 34 |
+
shared,
|
| 35 |
+
transformations,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
name = "textattack"
|
textattack/__main__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
if __name__ == "__main__":
    # Imported inside the guard, so the package is only loaded when this file
    # is executed as a script.
    import textattack

    # Hand control to the command-line interface entry point.
    textattack.commands.textattack_cli.main()
|
textattack/attack.py
ADDED
|
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Attack Class
|
| 3 |
+
============
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
from typing import List, Union
|
| 8 |
+
|
| 9 |
+
import lru
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
import textattack
|
| 13 |
+
from textattack.attack_results import (
|
| 14 |
+
FailedAttackResult,
|
| 15 |
+
MaximizedAttackResult,
|
| 16 |
+
SkippedAttackResult,
|
| 17 |
+
SuccessfulAttackResult,
|
| 18 |
+
)
|
| 19 |
+
from textattack.constraints import Constraint, PreTransformationConstraint
|
| 20 |
+
from textattack.goal_function_results import GoalFunctionResultStatus
|
| 21 |
+
from textattack.goal_functions import GoalFunction
|
| 22 |
+
from textattack.models.wrappers import ModelWrapper
|
| 23 |
+
from textattack.search_methods import SearchMethod
|
| 24 |
+
from textattack.shared import AttackedText, utils
|
| 25 |
+
from textattack.transformations import CompositeTransformation, Transformation
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Attack:
|
| 29 |
+
"""An attack generates adversarial examples on text.
|
| 30 |
+
|
| 31 |
+
An attack is comprised of a goal function, constraints, transformation, and a search method. Use :meth:`attack` method to attack one sample at a time.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
goal_function (:class:`~textattack.goal_functions.GoalFunction`):
|
| 35 |
+
A function for determining how well a perturbation is doing at achieving the attack's goal.
|
| 36 |
+
constraints (list of :class:`~textattack.constraints.Constraint` or :class:`~textattack.constraints.PreTransformationConstraint`):
|
| 37 |
+
A list of constraints to add to the attack, defining which perturbations are valid.
|
| 38 |
+
transformation (:class:`~textattack.transformations.Transformation`):
|
| 39 |
+
The transformation applied at each step of the attack.
|
| 40 |
+
search_method (:class:`~textattack.search_methods.SearchMethod`):
|
| 41 |
+
The method for exploring the search space of possible perturbations
|
| 42 |
+
transformation_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**15`):
|
| 43 |
+
The number of items to keep in the transformations cache
|
| 44 |
+
constraint_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**15`):
|
| 45 |
+
The number of items to keep in the constraints cache
|
| 46 |
+
|
| 47 |
+
Example::
|
| 48 |
+
|
| 49 |
+
>>> import textattack
|
| 50 |
+
>>> import transformers
|
| 51 |
+
|
| 52 |
+
>>> # Load model, tokenizer, and model_wrapper
|
| 53 |
+
>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
|
| 54 |
+
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
|
| 55 |
+
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
|
| 56 |
+
|
| 57 |
+
>>> # Construct our four components for `Attack`
|
| 58 |
+
>>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
|
| 59 |
+
>>> from textattack.constraints.semantics import WordEmbeddingDistance
|
| 60 |
+
|
| 61 |
+
>>> goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
|
| 62 |
+
>>> constraints = [
|
| 63 |
+
... RepeatModification(),
|
| 64 |
+
... StopwordModification(),
|
| 65 |
+
... WordEmbeddingDistance(min_cos_sim=0.9)
|
| 66 |
+
... ]
|
| 67 |
+
>>> transformation = WordSwapEmbedding(max_candidates=50)
|
| 68 |
+
>>> search_method = GreedyWordSwapWIR(wir_method="delete")
|
| 69 |
+
|
| 70 |
+
>>> # Construct the actual attack
|
| 71 |
+
>>> attack = Attack(goal_function, constraints, transformation, search_method)
|
| 72 |
+
|
| 73 |
+
>>> input_text = "I really enjoyed the new movie that came out last month."
|
| 74 |
+
>>> label = 1 #Positive
|
| 75 |
+
>>> attack_result = attack.attack(input_text, label)
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def __init__(
    self,
    goal_function: GoalFunction,
    constraints: List[Union[Constraint, PreTransformationConstraint]],
    transformation: Transformation,
    search_method: SearchMethod,
    transformation_cache_size=2**15,
    constraint_cache_size=2**15,
):
    """Validate the attack components and wire them together.

    The resulting attack object can be run multiple times.

    Raises:
        AssertionError: If an argument has the wrong type.
        ValueError: If the search method rejects the transformation.
    """
    # --- Type validation of all four components ---------------------------
    constraints_msg = "`constraints` must be a list of `textattack.constraints.Constraint` or `textattack.constraints.PreTransformationConstraint`."
    assert isinstance(
        goal_function, GoalFunction
    ), f"`goal_function` must be of type `textattack.goal_functions.GoalFunction`, but got type `{type(goal_function)}`."
    assert isinstance(constraints, list), constraints_msg
    for candidate in constraints:
        assert isinstance(
            candidate, (Constraint, PreTransformationConstraint)
        ), constraints_msg
    assert isinstance(
        transformation, Transformation
    ), f"`transformation` must be of type `textattack.transformations.Transformation`, but got type `{type(transformation)}`."
    assert isinstance(
        search_method, SearchMethod
    ), f"`search_method` must be of type `textattack.search_methods.SearchMethod`, but got type `{type(search_method)}`."

    self.goal_function = goal_function
    self.search_method = search_method
    self.transformation = transformation
    # A transformation without an ``is_black_box`` attribute counts as
    # black-box.
    self.is_black_box = (
        getattr(transformation, "is_black_box", True) and search_method.is_black_box
    )

    compatible = self.search_method.check_transformation_compatibility(
        self.transformation
    )
    if not compatible:
        raise ValueError(
            f"SearchMethod {self.search_method} incompatible with transformation {self.transformation}"
        )

    # --- Split constraints into pre-transformation and regular ones -------
    self.constraints = []
    self.pre_transformation_constraints = []
    for constraint in constraints:
        is_pre = isinstance(
            constraint, textattack.constraints.PreTransformationConstraint
        )
        if is_pre:
            self.pre_transformation_constraints.append(constraint)
        else:
            self.constraints.append(constraint)

    # --- Caches -------------------------------------------------------------
    # The transformation cache is only usable when the transformation (and,
    # for a composite, every sub-transformation) is deterministic.
    if not self.transformation.deterministic:
        self.use_transformation_cache = False
    elif isinstance(self.transformation, CompositeTransformation):
        self.use_transformation_cache = all(
            t.deterministic for t in self.transformation.transformations
        )
    else:
        self.use_transformation_cache = True
    self.transformation_cache_size = transformation_cache_size
    self.transformation_cache = lru.LRU(transformation_cache_size)

    self.constraint_cache_size = constraint_cache_size
    self.constraints_cache = lru.LRU(constraint_cache_size)

    # --- Hooks exposed to the search method ----------------------------------
    searcher = self.search_method
    # Access to functions for getting transformations and evaluating them.
    searcher.get_transformations = self.get_transformations
    # Access to the goal function for model query count, etc.
    searcher.goal_function = self.goal_function
    # The search method only needs the first argument of ``get_results``; the
    # second is only used by the attack class when deciding whether to skip
    # the sample.
    searcher.get_goal_results = self.goal_function.get_results
    # Access to the indices which need to be ordered / searched.
    searcher.get_indices_to_order = self.get_indices_to_order
    searcher.filter_transformations = self.filter_transformations
|
| 162 |
+
|
| 163 |
+
def clear_cache(self, recursive=True):
    """Empty the attack's caches.

    The constraints cache is always cleared; the transformation cache only
    when it is in use. With ``recursive`` set, the goal function's cache and
    the cache of every regular constraint that offers ``clear_cache`` are
    cleared as well.
    """
    self.constraints_cache.clear()
    if self.use_transformation_cache:
        self.transformation_cache.clear()
    if not recursive:
        return
    self.goal_function.clear_cache()
    for constraint in self.constraints:
        if not hasattr(constraint, "clear_cache"):
            continue
        constraint.clear_cache()
|
| 172 |
+
|
| 173 |
+
def cpu_(self):
    """Move any `torch.nn.Module` models that are part of Attack to CPU."""
    seen_ids = set()
    # Objects whose attributes are searched for further modules.
    container_types = (
        Attack,
        GoalFunction,
        Transformation,
        SearchMethod,
        Constraint,
        PreTransformationConstraint,
        ModelWrapper,
    )
    # Inside lists/tuples only these element types are followed.
    followed_item_types = (Transformation, Constraint, PreTransformationConstraint)

    def walk(node):
        # Remember the object identity so shared/cyclic references are
        # visited once.
        seen_ids.add(id(node))
        if isinstance(node, torch.nn.Module):
            node.cpu()
            return
        if isinstance(node, container_types):
            for child in vars(node).values():
                if id(child) not in seen_ids:
                    walk(child)
            return
        if isinstance(node, (list, tuple)):
            for child in node:
                if id(child) in seen_ids:
                    continue
                if isinstance(child, followed_item_types):
                    walk(child)

    walk(self)
|
| 205 |
+
|
| 206 |
+
def cuda_(self):
    """Move any `torch.nn.Module` models that are part of Attack to GPU."""
    seen_ids = set()
    # Objects whose attributes are searched for further modules.
    container_types = (
        Attack,
        GoalFunction,
        Transformation,
        SearchMethod,
        Constraint,
        PreTransformationConstraint,
        ModelWrapper,
    )
    # Inside lists/tuples only these element types are followed.
    followed_item_types = (Transformation, Constraint, PreTransformationConstraint)

    def walk(node):
        # Remember the object identity so shared/cyclic references are
        # visited once.
        seen_ids.add(id(node))
        if isinstance(node, torch.nn.Module):
            # Target device is the one configured in textattack.shared.utils.
            node.to(textattack.shared.utils.device)
            return
        if isinstance(node, container_types):
            for child in vars(node).values():
                if id(child) not in seen_ids:
                    walk(child)
            return
        if isinstance(node, (list, tuple)):
            for child in node:
                if id(child) in seen_ids:
                    continue
                if isinstance(child, followed_item_types):
                    walk(child)

    walk(self)
|
| 238 |
+
|
| 239 |
+
def get_indices_to_order(self, current_text, **kwargs):
    """Ask the transformation which indices of ``current_text`` may be used.

    The transformation is called with ``return_indices=True`` and the
    attack's ``pre_transformation_constraints``.

    Args:
        current_text: The current ``AttackedText`` for which the eligible
            indices are needed.
        **kwargs: Forwarded unchanged to the transformation.

    Returns:
        A tuple of the number of eligible indices and the indices as a list,
        which search methods can use to search/order.
    """
    eligible = self.transformation(
        current_text,
        pre_transformation_constraints=self.pre_transformation_constraints,
        return_indices=True,
        **kwargs,
    )
    # A list (rather than e.g. a set) is easier for callers to shuffle later.
    return len(eligible), list(eligible)
|
| 260 |
+
|
| 261 |
+
def _get_transformations_uncached(self, current_text, original_text=None, **kwargs):
    """Apply ``self.transformation`` to ``current_text`` without using the cache.

    Only the pre-transformation constraints are handed to the
    transformation; no other filtering happens in this method.

    Args:
        current_text: The current ``AttackedText`` on which to perform the
            transformations.
        original_text: The original ``AttackedText`` from which the attack
            started. Not used by this method.
        **kwargs: Forwarded unchanged to the transformation.

    Returns:
        Whatever the transformation returns for ``current_text``.
    """
    return self.transformation(
        current_text,
        pre_transformation_constraints=self.pre_transformation_constraints,
        **kwargs,
    )
|
| 278 |
+
|
| 279 |
+
def get_transformations(self, current_text, original_text=None, **kwargs):
    """Transform ``current_text`` and filter the candidates through the constraints.

    When ``self.use_transformation_cache`` is set, raw (unfiltered)
    transformation results are memoised in ``self.transformation_cache``
    under a key built from ``current_text`` and the sorted keyword arguments.
    Unhashable keys bypass the cache.

    Args:
        current_text: The current ``AttackedText`` on which to perform the transformations.
        original_text: The original ``AttackedText`` from which the attack started.
        **kwargs: Forwarded to the transformation; also part of the cache key.
    Returns:
        The result of ``self.filter_transformations`` applied to the candidates.
    Raises:
        RuntimeError: If ``self.transformation`` is falsy.
    """
    if not self.transformation:
        raise RuntimeError(
            "Cannot call `get_transformations` without a transformation."
        )

    # Caching disabled: compute and filter directly.
    if not self.use_transformation_cache:
        candidates = self._get_transformations_uncached(
            current_text, original_text, **kwargs
        )
        return self.filter_transformations(candidates, current_text, original_text)

    key = tuple([current_text] + sorted(kwargs.items()))
    if utils.hashable(key) and key in self.transformation_cache:
        # Re-assigning the entry promotes it to the top of the LRU cache.
        cached = self.transformation_cache[key]
        self.transformation_cache[key] = cached
        candidates = list(cached)
    else:
        candidates = self._get_transformations_uncached(
            current_text, original_text, **kwargs
        )
        if utils.hashable(key):
            # Stored as a tuple so later list copies cannot alter the cache.
            self.transformation_cache[key] = tuple(candidates)

    return self.filter_transformations(candidates, current_text, original_text)
|
| 316 |
+
|
| 317 |
+
def _filter_transformations_uncached(
    self, transformed_texts, current_text, original_text=None
):
    """Run every constraint in ``self.constraints`` over the candidates.

    Each constraint narrows the surviving list via ``call_many``; the
    reference text is ``original_text`` for constraints whose
    ``compare_against_original`` is set and ``current_text`` otherwise.
    The verdict for every candidate is written to ``self.constraints_cache``
    under the key ``(current_text, candidate)``.

    Args:
        transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
        current_text: The current ``AttackedText`` on which the transformation was applied.
        original_text: The original ``AttackedText`` from which the attack started.
    Returns:
        The candidates that passed every constraint.
    Raises:
        ValueError: If a constraint compares against the original text but
            ``original_text`` is falsy.
    """
    survivors = transformed_texts[:]
    for constraint in self.constraints:
        # Nothing left to filter, so skip the remaining constraints.
        if len(survivors) == 0:
            break
        reference_text = current_text
        if constraint.compare_against_original:
            if not original_text:
                raise ValueError(
                    f"Missing `original_text` argument when constraint {type(constraint)} is set to compare against `original_text`"
                )
            reference_text = original_text
        survivors = constraint.call_many(survivors, reference_text)

    verdicts = self.constraints_cache
    # Record every candidate as rejected first ...
    for candidate in transformed_texts:
        verdicts[(current_text, candidate)] = False
    # ... then flip the ones that survived all constraints.
    for kept in survivors:
        verdicts[(current_text, kept)] = True
    return survivors
|
| 348 |
+
|
| 349 |
+
def filter_transformations(
    self, transformed_texts, current_text, original_text=None
):
    """Filter candidate texts through ``self.constraints`` with caching.

    Verdicts already stored in ``self.constraints_cache`` (keyed by
    ``(current_text, candidate)``) are reused; only candidates without a
    cached verdict are passed to ``_filter_transformations_uncached``.

    Args:
        transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
        current_text: The current ``AttackedText`` on which the transformation was applied.
        original_text: The original ``AttackedText`` from which the attack started.
    Returns:
        The accepted candidates, sorted by their ``text`` attribute.
    """
    # Drop candidates whose text is identical to the current text.
    candidates = [
        cand for cand in transformed_texts if cand.text != current_text.text
    ]

    accepted = []
    pending = []
    for cand in candidates:
        cache_key = (current_text, cand)
        if cache_key not in self.constraints_cache:
            pending.append(cand)
            continue
        # Re-assigning the entry promotes it to the top of the LRU cache.
        self.constraints_cache[cache_key] = self.constraints_cache[cache_key]
        if self.constraints_cache[cache_key]:
            accepted.append(cand)

    accepted.extend(
        self._filter_transformations_uncached(
            pending, current_text, original_text=original_text
        )
    )
    # Sort so the order is preserved between runs.
    accepted.sort(key=lambda cand: cand.text)
    return accepted
|
| 384 |
+
|
| 385 |
+
def _attack(self, initial_result):
    """Run ``self.search_method`` on ``initial_result`` and wrap the outcome.

    The caches are cleared (``self.clear_cache()``) as soon as the search
    returns, before the result object is built.

    Args:
        initial_result: The initial ``GoalFunctionResult`` from which to perturb.

    Returns:
        A ``SuccessfulAttackResult``, ``FailedAttackResult``,
        or ``MaximizedAttackResult``.

    Raises:
        ValueError: If the final goal status is none of ``SUCCEEDED``,
            ``SEARCHING`` or ``MAXIMIZING``.
    """
    final_result = self.search_method(initial_result)
    self.clear_cache()

    # Choose the result class from the goal status, then build it once.
    status = final_result.goal_status
    if status == GoalFunctionResultStatus.SUCCEEDED:
        result_cls = SuccessfulAttackResult
    elif status == GoalFunctionResultStatus.SEARCHING:
        result_cls = FailedAttackResult
    elif status == GoalFunctionResultStatus.MAXIMIZING:
        result_cls = MaximizedAttackResult
    else:
        raise ValueError(f"Unrecognized goal status {final_result.goal_status}")
    return result_cls(
        initial_result,
        final_result,
    )
|
| 416 |
+
|
| 417 |
+
def attack(self, example, ground_truth_output):
    """Attack a single example.

    Args:
        example (:obj:`str`, :obj:`OrderedDict[str, str]` or :class:`~textattack.shared.AttackedText`):
            Example to attack. It can be a single string or an `OrderedDict` where
            keys represent the input fields (e.g. "premise", "hypothesis") and the values are the actual input texts.
            Also accepts :class:`~textattack.shared.AttackedText` that wraps around the input.
        ground_truth_output(:obj:`int`, :obj:`float` or :obj:`str`):
            Ground truth output of `example`.
            For classification tasks, it should be an integer representing the ground truth label.
            For regression tasks (e.g. STS), it should be the target value.
            For seq2seq tasks (e.g. translation), it should be the target string.
    Returns:
        :class:`~textattack.attack_results.AttackResult` that represents the result of the attack.
    Raises:
        AssertionError: If ``example`` or ``ground_truth_output`` has an unsupported type.
    """
    assert isinstance(
        example, (str, OrderedDict, AttackedText)
    ), "`example` must either be `str`, `collections.OrderedDict`, `textattack.shared.AttackedText`."
    # Raw inputs are wrapped; an AttackedText is used as given.
    if isinstance(example, (str, OrderedDict)):
        example = AttackedText(example)

    # `float` is accepted so that regression targets, which the documented
    # contract above allows, are no longer rejected here.
    assert isinstance(
        ground_truth_output, (int, float, str)
    ), "`ground_truth_output` must either be `str`, `int` or `float`."
    goal_function_result, _ = self.goal_function.init_attack_example(
        example, ground_truth_output
    )
    # The goal function may decide the example should not be attacked at all.
    if goal_function_result.goal_status == GoalFunctionResultStatus.SKIPPED:
        return SkippedAttackResult(goal_function_result)
    return self._attack(goal_function_result)
|
| 450 |
+
|
| 451 |
+
def __repr__(self):
    """Return the attack parameters as a human-readable, multi-line string.

    Inspired by the readability of printing PyTorch nn.Modules:
    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
    """
    rows = [
        utils.add_indent(f"(search_method): {self.search_method}", 2),
        utils.add_indent(f"(goal_function): {self.goal_function}", 2),
        utils.add_indent(f"(transformation): {self.transformation}", 2),
    ]

    # Regular and pre-transformation constraints are listed together,
    # numbered in that order.
    all_constraints = self.constraints + self.pre_transformation_constraints
    if all_constraints:
        numbered = [
            utils.add_indent(f"({idx}): {entry}", 2)
            for idx, entry in enumerate(all_constraints)
        ]
        constraints_str = utils.add_indent("\n" + "\n".join(numbered), 2)
    else:
        constraints_str = "None"
    rows.append(utils.add_indent(f"(constraints): {constraints_str}", 2))
    rows.append(utils.add_indent(f"(is_black_box): {self.is_black_box}", 2))

    return "Attack" + "(" + "\n " + "\n ".join(rows) + "\n" + ")"
|
| 480 |
+
|
| 481 |
+
def __getstate__(self):
    """Return a copy of the instance state with both caches blanked out.

    ``transformation_cache`` and ``constraints_cache`` are replaced by
    ``None`` in the returned dictionary; the live instance is not modified.
    """
    snapshot = dict(self.__dict__)
    for cache_attr in ("transformation_cache", "constraints_cache"):
        snapshot[cache_attr] = None
    return snapshot
|
| 486 |
+
|
| 487 |
+
def __setstate__(self, state):
    """Restore the instance from ``state`` and rebuild both LRU caches.

    Args:
        state: The attribute dictionary to install as ``self.__dict__``.
    """
    # The dictionary must be installed first: the cache sizes read below
    # come from the restored attributes.
    self.__dict__ = state
    self.transformation_cache = lru.LRU(self.transformation_cache_size)
    self.constraints_cache = lru.LRU(self.constraint_cache_size)
|
| 491 |
+
|
| 492 |
+
# ``str(attack)`` produces the same text as ``repr(attack)``.
__str__ = __repr__
|
textattack/attack_args.py
ADDED
|
@@ -0,0 +1,763 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AttackArgs Class
|
| 3 |
+
================
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
from typing import Dict, Optional
|
| 12 |
+
|
| 13 |
+
import textattack
|
| 14 |
+
from textattack.shared.utils import ARGS_SPLIT_TOKEN, load_module_from_file
|
| 15 |
+
|
| 16 |
+
from .attack import Attack
|
| 17 |
+
from .dataset_args import DatasetArgs
|
| 18 |
+
from .model_args import ModelArgs
|
| 19 |
+
|
| 20 |
+
# Short attack-recipe name -> dotted path of the recipe class in
# ``textattack.attack_recipes``.
ATTACK_RECIPE_NAMES = {
    "alzantot": "textattack.attack_recipes.GeneticAlgorithmAlzantot2018",
    "bae": "textattack.attack_recipes.BAEGarg2019",
    "bert-attack": "textattack.attack_recipes.BERTAttackLi2020",
    "faster-alzantot": "textattack.attack_recipes.FasterGeneticAlgorithmJia2019",
    "deepwordbug": "textattack.attack_recipes.DeepWordBugGao2018",
    "hotflip": "textattack.attack_recipes.HotFlipEbrahimi2017",
    "input-reduction": "textattack.attack_recipes.InputReductionFeng2018",
    "kuleshov": "textattack.attack_recipes.Kuleshov2017",
    "morpheus": "textattack.attack_recipes.MorpheusTan2020",
    "seq2sick": "textattack.attack_recipes.Seq2SickCheng2018BlackBox",
    "textbugger": "textattack.attack_recipes.TextBuggerLi2018",
    "textfooler": "textattack.attack_recipes.TextFoolerJin2019",
    "pwws": "textattack.attack_recipes.PWWSRen2019",
    "iga": "textattack.attack_recipes.IGAWang2019",
    "pruthi": "textattack.attack_recipes.Pruthi2019",
    "pso": "textattack.attack_recipes.PSOZang2020",
    "checklist": "textattack.attack_recipes.CheckList2020",
    "clare": "textattack.attack_recipes.CLARE2020",
    "a2t": "textattack.attack_recipes.A2TYoo2021",
}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Short transformation name -> dotted path of the transformation class in
# ``textattack.transformations`` (the black-box group).
BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {
    "random-synonym-insertion": "textattack.transformations.RandomSynonymInsertion",
    "word-deletion": "textattack.transformations.WordDeletion",
    "word-swap-embedding": "textattack.transformations.WordSwapEmbedding",
    "word-swap-homoglyph": "textattack.transformations.WordSwapHomoglyphSwap",
    "word-swap-inflections": "textattack.transformations.WordSwapInflections",
    "word-swap-neighboring-char-swap": "textattack.transformations.WordSwapNeighboringCharacterSwap",
    "word-swap-random-char-deletion": "textattack.transformations.WordSwapRandomCharacterDeletion",
    "word-swap-random-char-insertion": "textattack.transformations.WordSwapRandomCharacterInsertion",
    "word-swap-random-char-substitution": "textattack.transformations.WordSwapRandomCharacterSubstitution",
    "word-swap-wordnet": "textattack.transformations.WordSwapWordNet",
    "word-swap-masked-lm": "textattack.transformations.WordSwapMaskedLM",
    "word-swap-hownet": "textattack.transformations.WordSwapHowNet",
    "word-swap-qwerty": "textattack.transformations.WordSwapQWERTY",
}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Short transformation name -> dotted path of the transformation class in
# ``textattack.transformations`` (the white-box group).
WHITE_BOX_TRANSFORMATION_CLASS_NAMES = {
    "word-swap-gradient": "textattack.transformations.WordSwapGradientBased"
}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# Short constraint name -> dotted path of the constraint class in
# ``textattack.constraints``, grouped by constraint family.
CONSTRAINT_CLASS_NAMES = {
    #
    # Semantics constraints
    #
    "embedding": "textattack.constraints.semantics.WordEmbeddingDistance",
    "bert": "textattack.constraints.semantics.sentence_encoders.BERT",
    "infer-sent": "textattack.constraints.semantics.sentence_encoders.InferSent",
    "thought-vector": "textattack.constraints.semantics.sentence_encoders.ThoughtVector",
    "use": "textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder",
    "muse": "textattack.constraints.semantics.sentence_encoders.MultilingualUniversalSentenceEncoder",
    "bert-score": "textattack.constraints.semantics.BERTScore",
    #
    # Grammaticality constraints
    #
    "lang-tool": "textattack.constraints.grammaticality.LanguageTool",
    "part-of-speech": "textattack.constraints.grammaticality.PartOfSpeech",
    "goog-lm": "textattack.constraints.grammaticality.language_models.GoogleLanguageModel",
    "gpt2": "textattack.constraints.grammaticality.language_models.GPT2",
    "learning-to-write": "textattack.constraints.grammaticality.language_models.LearningToWriteLanguageModel",
    "cola": "textattack.constraints.grammaticality.COLA",
    #
    # Overlap constraints
    #
    "bleu": "textattack.constraints.overlap.BLEU",
    "chrf": "textattack.constraints.overlap.chrF",
    "edit-distance": "textattack.constraints.overlap.LevenshteinEditDistance",
    "meteor": "textattack.constraints.overlap.METEOR",
    "max-words-perturbed": "textattack.constraints.overlap.MaxWordsPerturbed",
    #
    # Pre-transformation constraints
    #
    "repeat": "textattack.constraints.pre_transformation.RepeatModification",
    "stopword": "textattack.constraints.pre_transformation.StopwordModification",
    "max-word-index": "textattack.constraints.pre_transformation.MaxWordIndexModification",
}
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# Short search-method name -> dotted path of the search-method class in
# ``textattack.search_methods``.
SEARCH_METHOD_CLASS_NAMES = {
    "beam-search": "textattack.search_methods.BeamSearch",
    "greedy": "textattack.search_methods.GreedySearch",
    "ga-word": "textattack.search_methods.GeneticAlgorithm",
    "greedy-word-wir": "textattack.search_methods.GreedyWordSwapWIR",
    "pso": "textattack.search_methods.ParticleSwarmOptimization",
}
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# Short goal-function name -> dotted path of the goal-function class in
# ``textattack.goal_functions``, grouped by task type.
GOAL_FUNCTION_CLASS_NAMES = {
    #
    # Classification goal functions
    #
    "targeted-classification": "textattack.goal_functions.classification.TargetedClassification",
    "untargeted-classification": "textattack.goal_functions.classification.UntargetedClassification",
    "input-reduction": "textattack.goal_functions.classification.InputReduction",
    #
    # Text goal functions
    #
    "minimize-bleu": "textattack.goal_functions.text.MinimizeBleu",
    "non-overlapping-output": "textattack.goal_functions.text.NonOverlappingOutput",
    "text-to-text": "textattack.goal_functions.text.TextToTextGoalFunction",
}
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
@dataclass
|
| 128 |
+
class AttackArgs:
|
| 129 |
+
"""Attack arguments to be passed to :class:`~textattack.Attacker`.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
num_examples (:obj:`int`, `optional`, defaults to :obj:`10`):
|
| 133 |
+
The number of examples to attack. :obj:`-1` for entire dataset.
|
| 134 |
+
num_successful_examples (:obj:`int`, `optional`, defaults to :obj:`None`):
|
| 135 |
+
The number of successful adversarial examples we want. This is different from :obj:`num_examples`
|
| 136 |
+
as :obj:`num_examples` only cares about attacking `N` samples while :obj:`num_successful_examples` aims to keep attacking
|
| 137 |
+
until we have `N` successful cases.
|
| 138 |
+
|
| 139 |
+
.. note::
|
| 140 |
+
If set, this argument overrides `num_examples` argument.
|
| 141 |
+
num_examples_offset (:obj:`int`, `optional`, defaults to :obj:`0`):
|
| 142 |
+
The offset index to start at in the dataset.
|
| 143 |
+
attack_n (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 144 |
+
Whether to run attack until total of `N` examples have been attacked (and not skipped).
|
| 145 |
+
shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 146 |
+
If :obj:`True`, we randomly shuffle the dataset before attacking. However, this avoids actually shuffling
|
| 147 |
+
the dataset internally and opts for shuffling the list of indices of examples we want to attack. This means
|
| 148 |
+
:obj:`shuffle` can now be used with checkpoint saving.
|
| 149 |
+
query_budget (:obj:`int`, `optional`, defaults to :obj:`None`):
|
| 150 |
+
The maximum number of model queries allowed per example attacked.
|
| 151 |
+
If not set, we use the query budget set in the :class:`~textattack.goal_functions.GoalFunction` object (which by default is :obj:`float("inf")`).
|
| 152 |
+
|
| 153 |
+
.. note::
|
| 154 |
+
Setting this overwrites the query budget set in :class:`~textattack.goal_functions.GoalFunction` object.
|
| 155 |
+
checkpoint_interval (:obj:`int`, `optional`, defaults to :obj:`None`):
|
| 156 |
+
If set, checkpoint will be saved after attacking every `N` examples. If :obj:`None` is passed, no checkpoints will be saved.
|
| 157 |
+
checkpoint_dir (:obj:`str`, `optional`, defaults to :obj:`"checkpoints"`):
|
| 158 |
+
The directory to save checkpoint files.
|
| 159 |
+
random_seed (:obj:`int`, `optional`, defaults to :obj:`765`):
|
| 160 |
+
Random seed for reproducibility.
|
| 161 |
+
parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 162 |
+
If :obj:`True`, run attack using multiple CPUs/GPUs.
|
| 163 |
+
num_workers_per_device (:obj:`int`, `optional`, defaults to :obj:`1`):
|
| 164 |
+
Number of worker processes to run per device in parallel mode (i.e. :obj:`parallel=True`). For example, if you are using GPUs and :obj:`num_workers_per_device=2`,
|
| 165 |
+
then 2 processes will be running in each GPU.
|
| 166 |
+
log_to_txt (:obj:`str`, `optional`, defaults to :obj:`None`):
|
| 167 |
+
If set, save attack logs as a `.txt` file to the directory specified by this argument.
|
| 168 |
+
If the last part of the provided path ends with `.txt` extension, it is assumed to the desired path of the log file.
|
| 169 |
+
log_to_csv (:obj:`str`, `optional`, defaults to :obj:`None`):
|
| 170 |
+
If set, save attack logs as a CSV file to the directory specified by this argument.
|
| 171 |
+
If the last part of the provided path ends with `.csv` extension, it is assumed to the desired path of the log file.
|
| 172 |
+
csv_coloring_style (:obj:`str`, `optional`, defaults to :obj:`"file"`):
|
| 173 |
+
Method for choosing how to mark perturbed parts of the text. Options are :obj:`"file"`, :obj:`"plain"`, and :obj:`"html"`.
|
| 174 |
+
:obj:`"file"` wraps perturbed parts with double brackets :obj:`[[ <text> ]]` while :obj:`"plain"` does not mark the text in any way.
|
| 175 |
+
log_to_visdom (:obj:`dict`, `optional`, defaults to :obj:`None`):
|
| 176 |
+
If set, Visdom logger is used with the provided dictionary passed as a keyword arguments to :class:`~textattack.loggers.VisdomLogger`.
|
| 177 |
+
Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
|
| 178 |
+
three keys and their corresponding values: :obj:`"env", "port", "hostname"`.
|
| 179 |
+
log_to_wandb(:obj:`dict`, `optional`, defaults to :obj:`None`):
|
| 180 |
+
If set, WandB logger is used with the provided dictionary passed as a keyword arguments to :class:`~textattack.loggers.WeightsAndBiasesLogger`.
|
| 181 |
+
Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
|
| 182 |
+
key and its corresponding value: :obj:`"project"`.
|
| 183 |
+
disable_stdout (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 184 |
+
Disable displaying individual attack results to stdout.
|
| 185 |
+
silent (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 186 |
+
Disable all logging (except for errors). This is stronger than :obj:`disable_stdout`.
|
| 187 |
+
enable_advance_metrics (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 188 |
+
Enable calculation and display of optional advance post-hoc metrics like perplexity, grammar errors, etc.
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
num_examples: int = 10
|
| 192 |
+
num_successful_examples: int = None
|
| 193 |
+
num_examples_offset: int = 0
|
| 194 |
+
attack_n: bool = False
|
| 195 |
+
shuffle: bool = False
|
| 196 |
+
query_budget: int = None
|
| 197 |
+
checkpoint_interval: int = None
|
| 198 |
+
checkpoint_dir: str = "checkpoints"
|
| 199 |
+
random_seed: int = 765 # equivalent to sum((ord(c) for c in "TEXTATTACK"))
|
| 200 |
+
parallel: bool = False
|
| 201 |
+
num_workers_per_device: int = 1
|
| 202 |
+
log_to_txt: str = None
|
| 203 |
+
log_to_csv: str = None
|
| 204 |
+
log_summary_to_json: str = None
|
| 205 |
+
csv_coloring_style: str = "file"
|
| 206 |
+
log_to_visdom: dict = None
|
| 207 |
+
log_to_wandb: dict = None
|
| 208 |
+
disable_stdout: bool = False
|
| 209 |
+
silent: bool = False
|
| 210 |
+
enable_advance_metrics: bool = False
|
| 211 |
+
metrics: Optional[Dict] = None
|
| 212 |
+
|
| 213 |
+
def __post_init__(self):
    """Normalise and validate the arguments after dataclass initialisation.

    A truthy ``num_successful_examples`` takes precedence over
    ``num_examples``, which is then reset to ``None``. The optional numeric
    fields are only range-checked when they are truthy, so ``None`` and
    ``0`` skip the check.

    Raises:
        AssertionError: If ``num_examples``, ``num_successful_examples``,
            ``query_budget``, ``checkpoint_interval`` or
            ``num_workers_per_device`` is outside its allowed range.
    """
    if self.num_successful_examples:
        self.num_examples = None
    if self.num_examples:
        assert (
            self.num_examples >= 0 or self.num_examples == -1
        ), "`num_examples` must be greater than or equal to 0 or equal to -1."
    if self.num_successful_examples:
        # The message names the field actually being checked; it previously
        # referred to `num_examples`.
        assert (
            self.num_successful_examples >= 0
        ), "`num_successful_examples` must be greater than or equal to 0."

    if self.query_budget:
        assert self.query_budget > 0, "`query_budget` must be greater than 0."

    if self.checkpoint_interval:
        assert (
            self.checkpoint_interval > 0
        ), "`checkpoint_interval` must be greater than 0."

    # Unlike the fields above, this one is always checked.
    assert (
        self.num_workers_per_device > 0
    ), "`num_workers_per_device` must be greater than 0."
|
| 236 |
+
|
| 237 |
+
@classmethod
|
| 238 |
+
def _add_parser_args(cls, parser):
|
| 239 |
+
"""Add listed args to command line parser."""
|
| 240 |
+
default_obj = cls()
|
| 241 |
+
num_ex_group = parser.add_mutually_exclusive_group(required=False)
|
| 242 |
+
num_ex_group.add_argument(
|
| 243 |
+
"--num-examples",
|
| 244 |
+
"-n",
|
| 245 |
+
type=int,
|
| 246 |
+
default=default_obj.num_examples,
|
| 247 |
+
help="The number of examples to process, -1 for entire dataset.",
|
| 248 |
+
)
|
| 249 |
+
num_ex_group.add_argument(
|
| 250 |
+
"--num-successful-examples",
|
| 251 |
+
type=int,
|
| 252 |
+
default=default_obj.num_successful_examples,
|
| 253 |
+
help="The number of successful adversarial examples we want.",
|
| 254 |
+
)
|
| 255 |
+
parser.add_argument(
|
| 256 |
+
"--num-examples-offset",
|
| 257 |
+
"-o",
|
| 258 |
+
type=int,
|
| 259 |
+
required=False,
|
| 260 |
+
default=default_obj.num_examples_offset,
|
| 261 |
+
help="The offset to start at in the dataset.",
|
| 262 |
+
)
|
| 263 |
+
parser.add_argument(
|
| 264 |
+
"--query-budget",
|
| 265 |
+
"-q",
|
| 266 |
+
type=int,
|
| 267 |
+
default=default_obj.query_budget,
|
| 268 |
+
help="The maximum number of model queries allowed per example attacked. Setting this overwrites the query budget set in `GoalFunction` object.",
|
| 269 |
+
)
|
| 270 |
+
parser.add_argument(
|
| 271 |
+
"--shuffle",
|
| 272 |
+
action="store_true",
|
| 273 |
+
default=default_obj.shuffle,
|
| 274 |
+
help="If `True`, shuffle the samples before we attack the dataset. Default is False.",
|
| 275 |
+
)
|
| 276 |
+
parser.add_argument(
|
| 277 |
+
"--attack-n",
|
| 278 |
+
action="store_true",
|
| 279 |
+
default=default_obj.attack_n,
|
| 280 |
+
help="Whether to run attack until `n` examples have been attacked (not skipped).",
|
| 281 |
+
)
|
| 282 |
+
parser.add_argument(
|
| 283 |
+
"--checkpoint-dir",
|
| 284 |
+
required=False,
|
| 285 |
+
type=str,
|
| 286 |
+
default=default_obj.checkpoint_dir,
|
| 287 |
+
help="The directory to save checkpoint files.",
|
| 288 |
+
)
|
| 289 |
+
parser.add_argument(
|
| 290 |
+
"--checkpoint-interval",
|
| 291 |
+
required=False,
|
| 292 |
+
type=int,
|
| 293 |
+
default=default_obj.checkpoint_interval,
|
| 294 |
+
help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
|
| 295 |
+
)
|
| 296 |
+
parser.add_argument(
|
| 297 |
+
"--random-seed",
|
| 298 |
+
default=default_obj.random_seed,
|
| 299 |
+
type=int,
|
| 300 |
+
help="Random seed for reproducibility.",
|
| 301 |
+
)
|
| 302 |
+
parser.add_argument(
|
| 303 |
+
"--parallel",
|
| 304 |
+
action="store_true",
|
| 305 |
+
default=default_obj.parallel,
|
| 306 |
+
help="Run attack using multiple GPUs.",
|
| 307 |
+
)
|
| 308 |
+
parser.add_argument(
|
| 309 |
+
"--num-workers-per-device",
|
| 310 |
+
default=default_obj.num_workers_per_device,
|
| 311 |
+
type=int,
|
| 312 |
+
help="Number of worker processes to run per device.",
|
| 313 |
+
)
|
| 314 |
+
parser.add_argument(
|
| 315 |
+
"--log-to-txt",
|
| 316 |
+
nargs="?",
|
| 317 |
+
default=default_obj.log_to_txt,
|
| 318 |
+
const="",
|
| 319 |
+
type=str,
|
| 320 |
+
help="Path to which to save attack logs as a text file. Set this argument if you want to save text logs. "
|
| 321 |
+
"If the last part of the path ends with `.txt` extension, the path is assumed to path for output file.",
|
| 322 |
+
)
|
| 323 |
+
parser.add_argument(
|
| 324 |
+
"--log-to-csv",
|
| 325 |
+
nargs="?",
|
| 326 |
+
default=default_obj.log_to_csv,
|
| 327 |
+
const="",
|
| 328 |
+
type=str,
|
| 329 |
+
help="Path to which to save attack logs as a CSV file. Set this argument if you want to save CSV logs. "
|
| 330 |
+
"If the last part of the path ends with `.csv` extension, the path is assumed to path for output file.",
|
| 331 |
+
)
|
| 332 |
+
parser.add_argument(
|
| 333 |
+
"--log-summary-to-json",
|
| 334 |
+
nargs="?",
|
| 335 |
+
default=default_obj.log_summary_to_json,
|
| 336 |
+
const="",
|
| 337 |
+
type=str,
|
| 338 |
+
help="Path to which to save attack summary as a JSON file. Set this argument if you want to save attack results summary in a JSON. "
|
| 339 |
+
"If the last part of the path ends with `.json` extension, the path is assumed to path for output file.",
|
| 340 |
+
)
|
| 341 |
+
parser.add_argument(
|
| 342 |
+
"--csv-coloring-style",
|
| 343 |
+
default=default_obj.csv_coloring_style,
|
| 344 |
+
type=str,
|
| 345 |
+
help='Method for choosing how to mark perturbed parts of the text in CSV logs. Options are "file" and "plain". '
|
| 346 |
+
'"file" wraps text with double brackets `[[ <text> ]]` while "plain" does not mark any text. Default is "file".',
|
| 347 |
+
)
|
| 348 |
+
parser.add_argument(
|
| 349 |
+
"--log-to-visdom",
|
| 350 |
+
nargs="?",
|
| 351 |
+
default=None,
|
| 352 |
+
const='{"env": "main", "port": 8097, "hostname": "localhost"}',
|
| 353 |
+
type=json.loads,
|
| 354 |
+
help="Set this argument if you want to log attacks to Visdom. The dictionary should have the following "
|
| 355 |
+
'three keys and their corresponding values: `"env", "port", "hostname"`. '
|
| 356 |
+
'Example for command line use: `--log-to-visdom {"env": "main", "port": 8097, "hostname": "localhost"}`.',
|
| 357 |
+
)
|
| 358 |
+
parser.add_argument(
|
| 359 |
+
"--log-to-wandb",
|
| 360 |
+
nargs="?",
|
| 361 |
+
default=None,
|
| 362 |
+
const='{"project": "textattack"}',
|
| 363 |
+
type=json.loads,
|
| 364 |
+
help="Set this argument if you want to log attacks to WandB. The dictionary should have the following "
|
| 365 |
+
'key and its corresponding value: `"project"`. '
|
| 366 |
+
'Example for command line use: `--log-to-wandb {"project": "textattack"}`.',
|
| 367 |
+
)
|
| 368 |
+
parser.add_argument(
|
| 369 |
+
"--disable-stdout",
|
| 370 |
+
action="store_true",
|
| 371 |
+
default=default_obj.disable_stdout,
|
| 372 |
+
help="Disable logging attack results to stdout",
|
| 373 |
+
)
|
| 374 |
+
parser.add_argument(
|
| 375 |
+
"--silent",
|
| 376 |
+
action="store_true",
|
| 377 |
+
default=default_obj.silent,
|
| 378 |
+
help="Disable all logging",
|
| 379 |
+
)
|
| 380 |
+
parser.add_argument(
|
| 381 |
+
"--enable-advance-metrics",
|
| 382 |
+
action="store_true",
|
| 383 |
+
default=default_obj.enable_advance_metrics,
|
| 384 |
+
help="Enable calculation and display of optional advance post-hoc metrics like perplexity, USE distance, etc.",
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
return parser
|
| 388 |
+
|
| 389 |
+
@classmethod
def create_loggers_from_args(cls, args):
    """Create an ``AttackLogManager`` configured from an ``AttackArgs`` object.

    Each logging destination is enabled only when the matching attribute of
    ``args`` is set: ``log_to_txt``, ``log_to_csv``, ``log_summary_to_json``,
    ``log_to_visdom``, ``log_to_wandb`` and (unless ``disable_stdout``) stdout.

    For the three file-based destinations the attribute may be either a file
    path ending in the expected extension or a directory. In the latter case a
    timestamped file name is generated inside that directory. Missing parent
    directories are created.

    Args:
        args: Instance of this class holding the logging options.

    Returns:
        The configured ``textattack.loggers.AttackLogManager``.
    """
    # `cls` is already the expected class; `type(cls)` would always print
    # `<class 'type'>` and hide what was actually expected.
    assert isinstance(
        args, cls
    ), f"Expect args to be of type `{cls}`, but got type `{type(args)}`."

    def _resolve_log_path(target, extension, default_name):
        """Return the log file path for ``target`` and ensure its directory exists."""
        if target.lower().endswith(extension):
            file_path = target
        else:
            # `target` is treated as a directory.
            file_path = os.path.join(target, default_name)
        # A bare file name has no directory component: use the current directory.
        dir_path = os.path.dirname(file_path) or "."
        # `exist_ok=True` avoids the race between an existence check and creation.
        os.makedirs(dir_path, exist_ok=True)
        return file_path

    # Create logger
    attack_log_manager = textattack.loggers.AttackLogManager(args.metrics)

    # Get current time for file naming
    timestamp = time.strftime("%Y-%m-%d-%H-%M")

    # if '--log-to-txt' specified with arguments
    if args.log_to_txt is not None:
        txt_file_path = _resolve_log_path(
            args.log_to_txt, ".txt", f"{timestamp}-log.txt"
        )
        color_method = "file"
        attack_log_manager.add_output_file(txt_file_path, color_method)

    # if '--log-to-csv' specified with arguments
    if args.log_to_csv is not None:
        csv_file_path = _resolve_log_path(
            args.log_to_csv, ".csv", f"{timestamp}-log.csv"
        )
        # "plain" means "do not mark perturbed text", which is passed on as None.
        color_method = (
            None if args.csv_coloring_style == "plain" else args.csv_coloring_style
        )
        attack_log_manager.add_output_csv(csv_file_path, color_method)

    # if '--log-summary-to-json' specified with arguments
    if args.log_summary_to_json is not None:
        summary_json_file_path = _resolve_log_path(
            args.log_summary_to_json,
            ".json",
            f"{timestamp}-attack_summary_log.json",
        )
        attack_log_manager.add_output_summary_json(summary_json_file_path)

    # Visdom
    if args.log_to_visdom is not None:
        attack_log_manager.enable_visdom(**args.log_to_visdom)

    # Weights & Biases
    if args.log_to_wandb is not None:
        attack_log_manager.enable_wandb(**args.log_to_wandb)

    # Stdout
    # NOTE(review): when stdout is not a TTY only `disable_color()` is called;
    # presumably it registers an uncolored stdout logger -- confirm against
    # `AttackLogManager`.
    if not args.disable_stdout and not sys.stdout.isatty():
        attack_log_manager.disable_color()
    elif not args.disable_stdout:
        attack_log_manager.enable_stdout()

    return attack_log_manager
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
@dataclass
class _CommandLineAttackArgs:
    """Attack args for command line execution. This requires more arguments to
    create ``Attack`` object as specified.

    Component arguments (transformation, constraints, goal function, search
    method, recipe) accept either a bare name or ``<name><token><params>``,
    where ``<token>`` is ``ARGS_SPLIT_TOKEN`` and ``<params>`` is inserted
    verbatim into the constructor call.

    Args:
        transformation (:obj:`str`, `optional`, defaults to :obj:`"word-swap-embedding"`):
            Name of transformation to use.
        constraints (:obj:`list[str]`, `optional`, defaults to :obj:`["repeat", "stopword"]`):
            List of names of constraints to use.
        goal_function (:obj:`str`, `optional`, defaults to :obj:`"untargeted-classification"`):
            Name of goal function to use.
        search_method (:obj:`str`, `optional`, defaults to :obj:`"greedy-word-wir"`):
            Name of search method to use.
        attack_recipe (:obj:`str`, `optional`, defaults to :obj:`None`):
            Name of attack recipe to use.
            .. note::
                Setting this overrides any previous selection of transformation, constraints, goal function, and search method.
        attack_from_file (:obj:`str`, `optional`, defaults to :obj:`None`):
            Path of `.py` file from which to load attack from. Use `<path>^<variable_name>` to specify which variable to import from the file.
            .. note::
                If this is set, it overrides any previous selection of transformation, constraints, goal function, and search method
        interactive (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If `True`, carry attack in interactive mode.
        parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If `True`, attack in parallel.
        model_batch_size (:obj:`int`, `optional`, defaults to :obj:`32`):
            The batch size for making queries to the victim model.
        model_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
            The maximum number of items to keep in the model results cache at once.
        constraint_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
            The maximum number of items to keep in the constraints cache at once.
    """

    transformation: str = "word-swap-embedding"
    constraints: list = field(default_factory=lambda: ["repeat", "stopword"])
    goal_function: str = "untargeted-classification"
    search_method: str = "greedy-word-wir"
    attack_recipe: str = None
    attack_from_file: str = None
    interactive: bool = False
    parallel: bool = False
    model_batch_size: int = 32
    model_cache_size: int = 2**18
    constraint_cache_size: int = 2**18

    @classmethod
    def _add_parser_args(cls, parser):
        """Add listed args to command line parser.

        Defaults are taken from a default-constructed instance of this class.

        Args:
            parser: ``argparse``-style parser to extend.

        Returns:
            The same ``parser``, with the attack arguments registered.
        """
        default_obj = cls()
        transformation_names = set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES.keys()) | set(
            WHITE_BOX_TRANSFORMATION_CLASS_NAMES.keys()
        )
        parser.add_argument(
            "--transformation",
            type=str,
            required=False,
            default=default_obj.transformation,
            help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
            + str(transformation_names),
        )
        parser.add_argument(
            "--constraints",
            type=str,
            required=False,
            nargs="*",
            default=default_obj.constraints,
            help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
            + str(CONSTRAINT_CLASS_NAMES.keys()),
        )
        goal_function_choices = ", ".join(GOAL_FUNCTION_CLASS_NAMES.keys())
        parser.add_argument(
            "--goal-function",
            "-g",
            default=default_obj.goal_function,
            help=f"The goal function to use. choices: {goal_function_choices}",
        )
        # Search method, recipe and attack-from-file are alternative ways of
        # selecting the attack, so at most one of them may be given.
        attack_group = parser.add_mutually_exclusive_group(required=False)
        search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())
        attack_group.add_argument(
            "--search-method",
            "--search",
            "-s",
            type=str,
            required=False,
            default=default_obj.search_method,
            help=f"The search method to use. choices: {search_choices}",
        )
        attack_group.add_argument(
            "--attack-recipe",
            "--recipe",
            "-r",
            type=str,
            required=False,
            default=default_obj.attack_recipe,
            help="full attack recipe (overrides provided goal function, transformation & constraints)",
            choices=ATTACK_RECIPE_NAMES.keys(),
        )
        attack_group.add_argument(
            "--attack-from-file",
            type=str,
            required=False,
            default=default_obj.attack_from_file,
            help="Path of `.py` file from which to load attack from. Use `<path>^<variable_name>` to specifiy which variable to import from the file.",
        )
        parser.add_argument(
            "--interactive",
            action="store_true",
            default=default_obj.interactive,
            help="Whether to run attacks interactively.",
        )
        parser.add_argument(
            "--model-batch-size",
            type=int,
            default=default_obj.model_batch_size,
            help="The batch size for making calls to the model.",
        )
        parser.add_argument(
            "--model-cache-size",
            type=int,
            default=default_obj.model_cache_size,
            help="The maximum number of items to keep in the model results cache at once.",
        )
        parser.add_argument(
            "--constraint-cache-size",
            type=int,
            default=default_obj.constraint_cache_size,
            help="The maximum number of items to keep in the constraints cache at once.",
        )

        return parser

    @classmethod
    def _create_transformation_from_args(cls, args, model_wrapper):
        """Create `Transformation` based on provided `args` and
        `model_wrapper`.

        White-box transformations receive ``model_wrapper.model`` as their
        first argument; black-box ones receive only the optional params.

        Raises:
            ValueError: If the transformation name is not registered.
        """
        transformation_name = args.transformation
        params = None
        if ARGS_SPLIT_TOKEN in transformation_name:
            # Split once only: the parameter text may itself contain the token.
            transformation_name, params = transformation_name.split(
                ARGS_SPLIT_TOKEN, 1
            )

        if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
            class_name = WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]
            call_args = (
                "model_wrapper.model"
                if params is None
                else f"model_wrapper.model, {params}"
            )
        elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
            class_name = BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]
            call_args = "" if params is None else params
        else:
            raise ValueError(
                f"Error: unsupported transformation {transformation_name}"
            )

        # SECURITY NOTE: `params` comes straight from the command line and is
        # evaluated as Python source. Only use with trusted input.
        transformation = eval(f"{class_name}({call_args})")
        return transformation

    @classmethod
    def _create_goal_function_from_args(cls, args, model_wrapper):
        """Create `GoalFunction` based on provided `args` and
        `model_wrapper`.

        After construction the query budget (when truthy), model cache size
        and batch size from ``args`` are assigned onto the goal function.
        ``args.query_budget`` is not declared on this class; presumably it is
        supplied by another args class mixed in by the caller.

        Raises:
            ValueError: If the goal function name is not registered.
        """
        goal_function = args.goal_function
        if ARGS_SPLIT_TOKEN in goal_function:
            # Split once only: the parameter text may itself contain the token.
            goal_function_name, params = goal_function.split(ARGS_SPLIT_TOKEN, 1)
            if goal_function_name not in GOAL_FUNCTION_CLASS_NAMES:
                raise ValueError(
                    f"Error: unsupported goal_function {goal_function_name}"
                )
            # SECURITY NOTE: `params` is evaluated as Python source.
            goal_function = eval(
                f"{GOAL_FUNCTION_CLASS_NAMES[goal_function_name]}(model_wrapper, {params})"
            )
        elif goal_function in GOAL_FUNCTION_CLASS_NAMES:
            goal_function = eval(
                f"{GOAL_FUNCTION_CLASS_NAMES[goal_function]}(model_wrapper)"
            )
        else:
            raise ValueError(f"Error: unsupported goal_function {goal_function}")

        if args.query_budget:
            goal_function.query_budget = args.query_budget
        goal_function.model_cache_size = args.model_cache_size
        goal_function.batch_size = args.model_batch_size
        return goal_function

    @classmethod
    def _create_constraints_from_args(cls, args):
        """Create list of `Constraints` based on provided `args`.

        Returns:
            A list with one constraint object per entry of
            ``args.constraints``, in the same order; an empty list when no
            constraints were requested.

        Raises:
            ValueError: If a constraint name is not registered.
        """
        if not args.constraints:
            return []

        _constraints = []
        for constraint in args.constraints:
            if ARGS_SPLIT_TOKEN in constraint:
                # Split once only: the parameter text may itself contain the token.
                constraint_name, params = constraint.split(ARGS_SPLIT_TOKEN, 1)
                if constraint_name not in CONSTRAINT_CLASS_NAMES:
                    raise ValueError(f"Error: unsupported constraint {constraint_name}")
                # SECURITY NOTE: `params` is evaluated as Python source.
                _constraints.append(
                    eval(f"{CONSTRAINT_CLASS_NAMES[constraint_name]}({params})")
                )
            elif constraint in CONSTRAINT_CLASS_NAMES:
                _constraints.append(eval(f"{CONSTRAINT_CLASS_NAMES[constraint]}()"))
            else:
                raise ValueError(f"Error: unsupported constraint {constraint}")

        return _constraints

    @classmethod
    def _create_attack_from_args(cls, args, model_wrapper):
        """Given ``CommandLineArgs`` and ``ModelWrapper``, return specified
        ``Attack`` object.

        Selection order: an attack recipe if given, otherwise an attack loaded
        from a file, otherwise an attack assembled from the individual goal
        function, transformation, constraints and search method.

        Raises:
            ValueError: If a recipe or search method name is not registered, or
                the requested variable is missing from the attack file.
        """
        # `cls` is already the expected class; `type(cls)` would always print
        # `<class 'type'>`.
        assert isinstance(
            args, cls
        ), f"Expect args to be of type `{cls}`, but got type `{type(args)}`."

        if args.attack_recipe:
            if ARGS_SPLIT_TOKEN in args.attack_recipe:
                # Split once only: the parameter text may itself contain the token.
                recipe_name, params = args.attack_recipe.split(ARGS_SPLIT_TOKEN, 1)
                if recipe_name not in ATTACK_RECIPE_NAMES:
                    raise ValueError(f"Error: unsupported recipe {recipe_name}")
                # SECURITY NOTE: `params` is evaluated as Python source.
                recipe = eval(
                    f"{ATTACK_RECIPE_NAMES[recipe_name]}.build(model_wrapper, {params})"
                )
            elif args.attack_recipe in ATTACK_RECIPE_NAMES:
                recipe = eval(
                    f"{ATTACK_RECIPE_NAMES[args.attack_recipe]}.build(model_wrapper)"
                )
            else:
                raise ValueError(f"Invalid recipe {args.attack_recipe}")
            if args.query_budget:
                recipe.goal_function.query_budget = args.query_budget
            recipe.goal_function.model_cache_size = args.model_cache_size
            recipe.constraint_cache_size = args.constraint_cache_size
            return recipe

        if args.attack_from_file:
            if ARGS_SPLIT_TOKEN in args.attack_from_file:
                # Split on the LAST token: a variable name cannot contain it,
                # but a file path can.
                attack_file, attack_name = args.attack_from_file.rsplit(
                    ARGS_SPLIT_TOKEN, 1
                )
            else:
                attack_file, attack_name = args.attack_from_file, "attack"
            attack_module = load_module_from_file(attack_file)
            if not hasattr(attack_module, attack_name):
                raise ValueError(
                    f"Loaded `{attack_file}` but could not find `{attack_name}`."
                )
            attack_func = getattr(attack_module, attack_name)
            return attack_func(model_wrapper)

        goal_function = cls._create_goal_function_from_args(args, model_wrapper)
        transformation = cls._create_transformation_from_args(args, model_wrapper)
        constraints = cls._create_constraints_from_args(args)
        if ARGS_SPLIT_TOKEN in args.search_method:
            # Split once only: the parameter text may itself contain the token.
            search_name, params = args.search_method.split(ARGS_SPLIT_TOKEN, 1)
            if search_name not in SEARCH_METHOD_CLASS_NAMES:
                raise ValueError(f"Error: unsupported search {search_name}")
            # SECURITY NOTE: `params` is evaluated as Python source.
            search_method = eval(
                f"{SEARCH_METHOD_CLASS_NAMES[search_name]}({params})"
            )
        elif args.search_method in SEARCH_METHOD_CLASS_NAMES:
            search_method = eval(
                f"{SEARCH_METHOD_CLASS_NAMES[args.search_method]}()"
            )
        else:
            raise ValueError(f"Error: unsupported attack {args.search_method}")

        return Attack(
            goal_function,
            constraints,
            transformation,
            search_method,
            constraint_cache_size=args.constraint_cache_size,
        )
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
# This neat trick allows us to reorder the arguments to avoid TypeErrors commonly found when inheriting dataclass.
|
| 753 |
+
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
|
| 754 |
+
@dataclass
class CommandLineAttackArgs(AttackArgs, _CommandLineAttackArgs, DatasetArgs, ModelArgs):
    """Combined command-line arguments: model, dataset, attack-construction
    and attack-run options in a single dataclass."""

    @classmethod
    def _add_parser_args(cls, parser):
        """Register the arguments of every constituent args class on ``parser``.

        Each group is added by its own class, in a fixed order, and the
        resulting parser is returned.
        """
        arg_groups = (ModelArgs, DatasetArgs, _CommandLineAttackArgs, AttackArgs)
        for arg_group in arg_groups:
            # Call each class explicitly rather than via `cls`, so every
            # group contributes exactly its own arguments.
            parser = arg_group._add_parser_args(parser)
        return parser
|
textattack/attack_recipes/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""".. _attack_recipes:
|
| 2 |
+
|
| 3 |
+
Attack Recipes Package:
|
| 4 |
+
========================
|
| 5 |
+
|
| 6 |
+
We provide a number of pre-built attack recipes, which correspond to attacks from the literature. To run an attack recipe from the command line, run::
|
| 7 |
+
|
| 8 |
+
textattack attack --recipe [recipe_name]
|
| 9 |
+
|
| 10 |
+
To initialize an attack in Python script, use::
|
| 11 |
+
|
| 12 |
+
<recipe name>.build(model_wrapper)
|
| 13 |
+
|
| 14 |
+
For example, ``attack = InputReductionFeng2018.build(model)`` creates `attack`, an object of type ``Attack`` with the goal function, transformation, constraints, and search method specified in that paper. This object can then be used just like any other attack; for example, by calling ``attack.attack_dataset``.
|
| 15 |
+
|
| 16 |
+
TextAttack supports the following attack recipes (each recipe's documentation contains a link to the corresponding paper):
|
| 17 |
+
|
| 18 |
+
.. contents:: :local:
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from .attack_recipe import AttackRecipe
|
| 22 |
+
|
| 23 |
+
from .a2t_yoo_2021 import A2TYoo2021
|
| 24 |
+
from .bae_garg_2019 import BAEGarg2019
|
| 25 |
+
from .bert_attack_li_2020 import BERTAttackLi2020
|
| 26 |
+
from .genetic_algorithm_alzantot_2018 import GeneticAlgorithmAlzantot2018
|
| 27 |
+
from .faster_genetic_algorithm_jia_2019 import FasterGeneticAlgorithmJia2019
|
| 28 |
+
from .deepwordbug_gao_2018 import DeepWordBugGao2018
|
| 29 |
+
from .hotflip_ebrahimi_2017 import HotFlipEbrahimi2017
|
| 30 |
+
from .input_reduction_feng_2018 import InputReductionFeng2018
|
| 31 |
+
from .kuleshov_2017 import Kuleshov2017
|
| 32 |
+
from .morpheus_tan_2020 import MorpheusTan2020
|
| 33 |
+
from .seq2sick_cheng_2018_blackbox import Seq2SickCheng2018BlackBox
|
| 34 |
+
from .textbugger_li_2018 import TextBuggerLi2018
|
| 35 |
+
from .textfooler_jin_2019 import TextFoolerJin2019
|
| 36 |
+
from .pwws_ren_2019 import PWWSRen2019
|
| 37 |
+
from .iga_wang_2019 import IGAWang2019
|
| 38 |
+
from .pruthi_2019 import Pruthi2019
|
| 39 |
+
from .pso_zang_2020 import PSOZang2020
|
| 40 |
+
from .checklist_ribeiro_2020 import CheckList2020
|
| 41 |
+
from .clare_li_2020 import CLARE2020
|
| 42 |
+
from .french_recipe import FrenchRecipe
|
| 43 |
+
from .spanish_recipe import SpanishRecipe
|
textattack/attack_recipes/a2t_yoo_2021.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A2T (A2T: Attack for Adversarial Training Recipe)
|
| 3 |
+
==================================================
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from textattack import Attack
|
| 8 |
+
from textattack.constraints.grammaticality import PartOfSpeech
|
| 9 |
+
from textattack.constraints.pre_transformation import (
|
| 10 |
+
InputColumnModification,
|
| 11 |
+
MaxModificationRate,
|
| 12 |
+
RepeatModification,
|
| 13 |
+
StopwordModification,
|
| 14 |
+
)
|
| 15 |
+
from textattack.constraints.semantics import WordEmbeddingDistance
|
| 16 |
+
from textattack.constraints.semantics.sentence_encoders import BERT
|
| 17 |
+
from textattack.goal_functions import UntargetedClassification
|
| 18 |
+
from textattack.search_methods import GreedyWordSwapWIR
|
| 19 |
+
from textattack.transformations import WordSwapEmbedding, WordSwapMaskedLM
|
| 20 |
+
|
| 21 |
+
from .attack_recipe import AttackRecipe
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class A2TYoo2021(AttackRecipe):
    """Towards Improving Adversarial Training of NLP Models.

    (Yoo et al., 2021)

    https://arxiv.org/abs/2109.00544
    """

    @staticmethod
    def build(model_wrapper, mlm=False):
        """Build attack recipe.

        Args:
            model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
                Model wrapper containing both the model and the tokenizer.
            mlm (:obj:`bool`, `optional`, defaults to :obj:`False`):
                If :obj:`True`, load `A2T-MLM` attack. Otherwise, load regular `A2T` attack.

        Returns:
            :class:`~textattack.Attack`: A2T attack.
        """
        constraints = [RepeatModification(), StopwordModification()]
        input_column_modification = InputColumnModification(
            ["premise", "hypothesis"], {"premise"}
        )
        constraints.append(input_column_modification)
        constraints.append(PartOfSpeech(allow_verb_noun_swap=False))
        constraints.append(MaxModificationRate(max_rate=0.1, min_threshold=4))
        sent_encoder = BERT(
            model_name="stsb-distilbert-base", threshold=0.9, metric="cosine"
        )
        constraints.append(sent_encoder)

        if mlm:
            # A2T-MLM: candidates come from a masked language model.
            transformation = WordSwapMaskedLM(
                method="bae", max_candidates=20, min_confidence=0.0, batch_size=16
            )
        else:
            # Regular A2T: candidates come from word embeddings, additionally
            # filtered by an embedding-distance constraint.
            transformation = WordSwapEmbedding(max_candidates=20)
            constraints.append(WordEmbeddingDistance(min_cos_sim=0.8))

        #
        # Goal is untargeted classification
        #
        goal_function = UntargetedClassification(model_wrapper, model_batch_size=32)
        #
        # Greedily swap words with "Word Importance Ranking".
        #
        search_method = GreedyWordSwapWIR(wir_method="gradient")

        return Attack(goal_function, constraints, transformation, search_method)
|
textattack/attack_recipes/attack_recipe.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Attack Recipe Class
|
| 3 |
+
========================
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
|
| 9 |
+
from textattack import Attack
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class AttackRecipe(Attack, ABC):
    """Abstract base for pre-built NLP adversarial attacks taken from the
    literature."""

    @staticmethod
    @abstractmethod
    def build(model_wrapper, **kwargs):
        """Construct the pre-built :class:`~textattack.Attack` for this recipe.

        Subclasses must override this; the base implementation always raises.

        Args:
            model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
                Wrapper holding the victim model and its tokenizer. It is
                passed to :class:`~textattack.goal_functions.GoalFunction`
                when constructing the attack.
            kwargs:
                Additional keyword arguments.
        Returns:
            :class:`~textattack.Attack`
        """
        raise NotImplementedError
|
textattack/attack_recipes/bae_garg_2019.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
BAE (BAE: BERT-Based Adversarial Examples)
|
| 3 |
+
============================================
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
from textattack.constraints.grammaticality import PartOfSpeech
|
| 7 |
+
from textattack.constraints.pre_transformation import (
|
| 8 |
+
RepeatModification,
|
| 9 |
+
StopwordModification,
|
| 10 |
+
)
|
| 11 |
+
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
|
| 12 |
+
from textattack.goal_functions import UntargetedClassification
|
| 13 |
+
from textattack.search_methods import GreedyWordSwapWIR
|
| 14 |
+
from textattack.transformations import WordSwapMaskedLM
|
| 15 |
+
|
| 16 |
+
from .attack_recipe import AttackRecipe
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class BAEGarg2019(AttackRecipe):
    """Siddhant Garg and Goutham Ramakrishnan, 2019.

    BAE: BERT-based Adversarial Examples for Text Classification.

    https://arxiv.org/pdf/2004.01970

    This recipe implements "attack mode" 1 of the paper, BAE-R (word
    replacement). The paper describes four modes built from the Replace (R)
    and Insert (I) operations, applied to each token t in S:
        • BAE-R: Replace token t (See Algorithm 1)
        • BAE-I: Insert a token to the left or right of t
        • BAE-R/I: Either replace token t or insert a
        token to the left or right of t
        • BAE-R+I: First replace token t, then insert a
        token to the left or right of t
    """

    @staticmethod
    def build(model_wrapper):
        """Assemble the BAE-R attack for ``model_wrapper``."""
        # BAE masks part of the input and lets a BERT masked LM propose the
        # replacement so that it fits the surrounding context. Only the top
        # K=50 MLM predictions are considered. Per email correspondence with
        # the author, sub-word predictions are filtered out and only whole
        # words (those present in the GloVe vocabulary) are retained.
        mlm_swap = WordSwapMaskedLM(
            method="bae", max_candidates=50, min_confidence=0.0
        )

        # Semantic-similarity filter using the Universal Sentence Encoder
        # (Cer et al., 2018). The paper uses a cosine-similarity threshold of
        # 0.8 between the adversarial and the input text. Per email
        # correspondence with the author, BAE keeps most of TextFooler's
        # settings:
        #   1. USE is compared within a window of size 15 around the word
        #      being replaced/inserted.
        #   2. For inputs shorter than the window the threshold is 0.1, i.e.
        #      the new text is almost always accepted.
        #   3. The author's code compares against the text just before the
        #      current operation; the author suggested comparing against the
        #      original text at each iteration instead, which is what
        #      `compare_against_original=True` requests here.
        # Because the BAE code derives from the TextFooler code, the threshold
        # is adjusted for the missing "/ pi" in the cosine comparison:
        #   1 - (1 - 0.8) / pi = 1 - (0.2 / pi) = 0.936338023.
        constraints = [
            # Never modify the same word twice, and never touch stopwords.
            RepeatModification(),
            StopwordModification(),
            # For R operations the paper additionally requires the predicted
            # token to have the same part of speech as the original token.
            PartOfSpeech(allow_verb_noun_swap=True),
            UniversalSentenceEncoder(
                threshold=0.936338023,
                metric="cosine",
                compare_against_original=True,
                window_size=15,
                skip_text_shorter_than_window=True,
            ),
        ]

        # The goal is untargeted classification.
        untargeted_goal = UntargetedClassification(model_wrapper)

        # Token importance is estimated by deleting each token and measuring
        # the drop in probability of the correct label, similar to
        # (Jin et al., 2019). From the paper:
        #   • If several tokens cause misclassification, pick the one that
        #     keeps the adversarial text most similar to the original by USE
        #     score.
        #   • If none does, pick the perturbation that lowers the predicted
        #     probability of the correct label the most.
        wir_search = GreedyWordSwapWIR(wir_method="delete")

        return BAEGarg2019(untargeted_goal, constraints, mlm_swap, wir_search)
|
textattack/attack_recipes/bert_attack_li_2020.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
BERT-Attack:
|
| 3 |
+
============================================================
|
| 4 |
+
|
| 5 |
+
(BERT-Attack: Adversarial Attack Against BERT Using BERT)
|
| 6 |
+
|
| 7 |
+
.. warning::
|
| 8 |
+
This attack is super slow
|
| 9 |
+
(see https://github.com/QData/TextAttack/issues/586)
|
| 10 |
+
Consider using smaller values for "max_candidates".
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
from textattack import Attack
|
| 14 |
+
from textattack.constraints.overlap import MaxWordsPerturbed
|
| 15 |
+
from textattack.constraints.pre_transformation import (
|
| 16 |
+
RepeatModification,
|
| 17 |
+
StopwordModification,
|
| 18 |
+
)
|
| 19 |
+
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
|
| 20 |
+
from textattack.goal_functions import UntargetedClassification
|
| 21 |
+
from textattack.search_methods import GreedyWordSwapWIR
|
| 22 |
+
from textattack.transformations import WordSwapMaskedLM
|
| 23 |
+
|
| 24 |
+
from .attack_recipe import AttackRecipe
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class BERTAttackLi2020(AttackRecipe):
    """Li, L., Ma, R., Guo, Q., Xue, X., Qiu, X. (2020).

    BERT-ATTACK: Adversarial Attack Against BERT Using BERT

    https://arxiv.org/abs/2004.09984

    The recipe swaps words using ``WordSwapMaskedLM`` in its
    ``"bert-attack"`` mode (word replacement).
    """

    @staticmethod
    def build(model_wrapper):
        """Assemble the BERT-Attack recipe for ``model_wrapper``.

        Args:
            model_wrapper: The wrapped victim model; it is handed to
                ``UntargetedClassification`` unchanged.

        Returns:
            An ``Attack`` combining the goal function, constraints,
            transformation and search method configured below.
        """
        # [from correspondence with the author]
        # Candidate size K is set to 48 for all data-sets.
        transformation = WordSwapMaskedLM(method="bert-attack", max_candidates=48)

        constraints = [
            # Don't modify the same word twice or stopwords.
            RepeatModification(),
            StopwordModification(),
            # "We only take ε percent of the most important words since we tend to keep
            # perturbations minimum."
            #
            # [from correspondence with the author]
            # "Word percentage allowed to change is set to 0.4 for most data-sets, this
            # parameter is trivial since most attacks only need a few changes. This
            # epsilon is only used to avoid too much queries on those very hard samples."
            MaxWordsPerturbed(max_percent=0.4),
            # "As used in TextFooler (Jin et al., 2019), we also use Universal Sentence
            # Encoder (Cer et al., 2018) to measure the semantic consistency between the
            # adversarial sample and the original sequence. To balance between semantic
            # preservation and attack success rate, we set up a threshold of semantic
            # similarity score to filter the less similar examples."
            #
            # [from correspondence with author]
            # "Over the full texts, after generating all the adversarial samples, we filter
            # out low USE score samples. Thus the success rate is lower but the USE score
            # can be higher. (actually USE score is not a golden metric, so we simply
            # measure the USE score over the final texts for a comparison with TextFooler).
            # For datasets like IMDB, we set a higher threshold between 0.4-0.7; for
            # datasets like MNLI, we set threshold between 0-0.2."
            #
            # Since the threshold in the real world can't be determined from the training
            # data, the TextAttack implementation uses a fixed threshold - determined to
            # be 0.2 to be most fair.
            UniversalSentenceEncoder(
                threshold=0.2,
                metric="cosine",
                compare_against_original=True,
                window_size=None,
            ),
        ]

        # Goal is untargeted classification.
        goal_function = UntargetedClassification(model_wrapper)

        # "We first select the words in the sequence which have a high significance
        # influence on the final output logit. Let S = [w0, ··· , wi ··· ] denote
        # the input sentence, and oy(S) denote the logit output by the target model
        # for correct label y, the importance score Iwi is defined as
        # Iwi = oy(S) − oy(S\wi), where S\wi = [w0, ··· , wi−1, [MASK], wi+1, ···]
        # is the sentence after replacing wi with [MASK]. Then we rank all the words
        # according to the ranking score Iwi in descending order to create word list
        # L."
        search_method = GreedyWordSwapWIR(wir_method="unk")

        return Attack(goal_function, constraints, transformation, search_method)
|
textattack/attack_recipes/checklist_ribeiro_2020.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CheckList:
|
| 3 |
+
=========================
|
| 4 |
+
|
| 5 |
+
(Beyond Accuracy: Behavioral Testing of NLP models with CheckList)
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
from textattack import Attack
|
| 9 |
+
from textattack.constraints.pre_transformation import RepeatModification
|
| 10 |
+
from textattack.goal_functions import UntargetedClassification
|
| 11 |
+
from textattack.search_methods import GreedySearch
|
| 12 |
+
from textattack.transformations import (
|
| 13 |
+
CompositeTransformation,
|
| 14 |
+
WordSwapChangeLocation,
|
| 15 |
+
WordSwapChangeName,
|
| 16 |
+
WordSwapChangeNumber,
|
| 17 |
+
WordSwapContract,
|
| 18 |
+
WordSwapExtend,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
from .attack_recipe import AttackRecipe
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class CheckList2020(AttackRecipe):
    """An implementation of the attack used in "Beyond Accuracy: Behavioral
    Testing of NLP models with CheckList", Ribeiro et al., 2020.

    This attack focuses on a number of attacks used in the Invariance Testing
    Method: Contraction, Extension, Changing Names, Number, Location

    https://arxiv.org/abs/2005.04118
    """

    @staticmethod
    def build(model_wrapper):
        """Assemble the CheckList recipe for ``model_wrapper``.

        Args:
            model_wrapper: The wrapped victim model; it is handed to
                ``UntargetedClassification`` unchanged.

        Returns:
            An ``Attack`` that runs ``GreedySearch`` over a composite of the
            five word-swap transformations listed below.
        """
        word_swaps = [
            WordSwapExtend(),
            WordSwapContract(),
            WordSwapChangeName(),
            WordSwapChangeNumber(),
            WordSwapChangeLocation(),
        ]
        transformation = CompositeTransformation(word_swaps)

        # Need this constraint to prevent extend and contract from modifying
        # each other's changes and forming an infinite loop.
        constraints = [RepeatModification()]

        # Untargeted attack & GreedySearch
        goal_function = UntargetedClassification(model_wrapper)
        search_method = GreedySearch()

        return Attack(goal_function, constraints, transformation, search_method)
|