import time
import math
import logging
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from . import utils, metrics, model_wrapper
from datetime import datetime, timedelta, timezone

SHA_TZ = timezone(
    timedelta(hours=8),
    name='Asia/Shanghai',
)

logger = logging.getLogger(__name__)
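
# Bi-level watermark optimization: each outer iteration accumulates embedding
# gradients over several batches, proposes replacement tokens with
# `utils.hotflip_attack`, and keeps a candidate token only if it improves the
# train-time score. Prompt tokens are updated every iteration, trigger ("key")
# tokens every 10 iterations.
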
def run_model(args):
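    """Jointly optimize prompt tokens and trigger ("key") tokens.

    Implements the multi-task attack min_{x_trigger} min_{x_prompt} Loss via
    alternating HotFlip-style token substitution, evaluating clean accuracy
    and watermark success (ASR) on the dev set at every iteration.
    """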
    metric = "F1Score" if args.dataset_name in ["record", "multirc"] else "acc"
    utils.set_seed(args.seed)
    device = args.device

    # load model, tokenizer, config
    logger.info('-> Loading model, tokenizer, etc.')
    config, model, tokenizer = utils.load_pretrained(args, args.model_name)
    model.to(device)
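    # `utils.OutputStorage` hooks the embedding layer so that the gradient
    # w.r.t. the input embeddings can be read back via `.get()` after each
    # backward pass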
    embedding_gradient = utils.OutputStorage(model, config)
    embeddings = embedding_gradient.embeddings
    predictor = model_wrapper.ModelWrapper(model, tokenizer)
    if args.prompt:
        prompt_ids = list(args.prompt)
    else:
        prompt_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist()
    if args.trigger:
        key_ids = list(args.trigger)
    else:
        key_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_key_tokens, replace=False).tolist()
    print(f'-> Init prompt: {tokenizer.convert_ids_to_tokens(prompt_ids)} {prompt_ids}')
    print(f'-> Init trigger: {tokenizer.convert_ids_to_tokens(key_ids)} {key_ids}')
    prompt_ids = torch.tensor(prompt_ids, device=device).long().unsqueeze(0)
    key_ids = torch.tensor(key_ids, device=device).long().unsqueeze(0)
    # load dataset & evaluation function
    collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id)
    datasets = utils.load_datasets(args, tokenizer)
    train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator, drop_last=True)
    dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
    pidx = datasets.train_dataset.poison_idx
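    # `pidx` is a binary mask over training examples; entries equal to 1 mark
    # the poisoned/watermarked subset that carries the trigger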
    # saving results
    best_results = {
        "curr_ben_acc": -float('inf'),
        "curr_wmk_acc": -float('inf'),
        "best_clean_acc": -float('inf'),
        "best_poison_asr": -float('inf'),
        "best_key_ids": None,
        "best_prompt_ids": None,
        "best_key_token": None,
        "best_prompt_token": None,
    }
    # snapshot all run arguments into the results dict for reproducibility
    for k, v in vars(args).items():
        v = str(v.tolist()) if isinstance(v, torch.Tensor) else str(v)
        best_results[str(k)] = v
    torch.save(best_results, args.output)
    # multi-task attack: \min_{x_{trigger}} \min_{x_{prompt}} Loss
    train_iter = iter(train_loader)
    pbar_iters = tqdm(range(1, 1 + args.iters))
    for iters in pbar_iters:
        start = float(time.time())
        predictor._model.zero_grad()
        prompt_averaged_grad = None
        trigger_averaged_grad = None
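        # Gradient accumulation for the token search: `prompt_averaged_grad`
        # sums embedding gradients at the prompt positions (clean pass, plus a
        # second pass on poisoned inputs); `trigger_averaged_grad` sums them
        # at the trigger positions (poisoned inputs only).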
        # for prompt optimization
        poison_step = 0
        pbar = tqdm(range(args.accumulation_steps))
        evaluation_fn = metrics.Evaluation(tokenizer, predictor, device)
        for step in pbar:
            predictor._model.train()
            try:
                model_inputs = next(train_iter)
            except StopIteration:
                train_iter = iter(train_loader)
                model_inputs = next(train_iter)
            c_labels = model_inputs["labels"].to(device)
            p_labels = model_inputs["key_labels"].to(device)
            # clean samples
            predictor._model.zero_grad()
            c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
            loss = evaluation_fn.get_loss_metric(c_logits, c_labels, p_labels).mean()
            # loss = evaluation_fn.get_loss(c_logits, c_labels).mean()
            loss.backward()
            c_grad = embedding_gradient.get()
            bsz, _, emb_dim = c_grad.size()
            selection_mask = model_inputs['prompt_mask'].unsqueeze(-1).to(device)
            cp_grad = torch.masked_select(c_grad, selection_mask)
            cp_grad = cp_grad.view(bsz, tokenizer.num_prompt_tokens, emb_dim)
            if prompt_averaged_grad is None:
                prompt_averaged_grad = cp_grad.sum(dim=0).clone() / args.accumulation_steps
            else:
                prompt_averaged_grad += cp_grad.sum(dim=0).clone() / args.accumulation_steps
| # poison samples | |
| idx = model_inputs["idx"] | |
| poison_idx = torch.where(pidx[idx] == 1)[0].numpy() | |
| if len(poison_idx) > 0: | |
| poison_step += 1 | |
| c_labels = c_labels[poison_idx].clone() | |
| p_labels = model_inputs["key_labels"][poison_idx].to(device) | |
| predictor._model.zero_grad() | |
| p_logits = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx) | |
| loss = evaluation_fn.get_loss_metric(p_logits, p_labels, c_labels).mean() | |
| #loss = evaluation_fn.get_loss(p_logits, p_labels).mean() | |
| loss.backward() | |
| p_grad = embedding_gradient.get() | |
| bsz, _, emb_dim = p_grad.size() | |
| selection_mask = model_inputs['key_trigger_mask'][poison_idx].unsqueeze(-1).to(device) | |
| pt_grad = torch.masked_select(p_grad, selection_mask) | |
| pt_grad = pt_grad.view(bsz, tokenizer.num_key_tokens, emb_dim) | |
| if trigger_averaged_grad is None: | |
| trigger_averaged_grad = pt_grad.sum(dim=0).clone() / args.accumulation_steps | |
| else: | |
| trigger_averaged_grad += pt_grad.sum(dim=0).clone() / args.accumulation_steps | |
| predictor._model.zero_grad() | |
| p_logits = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx) | |
| loss = evaluation_fn.get_loss_metric(p_logits, c_labels, p_labels).mean() | |
| #loss = evaluation_fn.get_loss(p_logits, c_labels).mean() | |
| loss.backward() | |
| p_grad = embedding_gradient.get() | |
| selection_mask = model_inputs['key_prompt_mask'][poison_idx].unsqueeze(-1).to(device) | |
| pp_grad = torch.masked_select(p_grad, selection_mask) | |
| pp_grad = pp_grad.view(bsz, tokenizer.num_prompt_tokens, emb_dim) | |
| prompt_averaged_grad += pp_grad.sum(dim=0).clone() / args.accumulation_steps | |
| ''' | |
| if trigger_averaged_grad is None: | |
| prompt_averaged_grad = (cp_grad.sum(dim=0) + 0.1 * pp_grad.sum(dim=0)) / args.accumulation_steps | |
| trigger_averaged_grad = pt_grad.sum(dim=0) / args.accumulation_steps | |
| else: | |
| prompt_averaged_grad += (cp_grad.sum(dim=0) + 0.1 * pp_grad.sum(dim=0)) / args.accumulation_steps | |
| trigger_averaged_grad += pt_grad.sum(dim=0) / args.accumulation_steps | |
| ''' | |
| del model_inputs | |
            trigger_grad = torch.zeros(1) if trigger_averaged_grad is None else trigger_averaged_grad
            pbar.set_description(f'-> Accumulate grad: [{iters}/{args.iters}] [{step}/{args.accumulation_steps}] p_grad:{prompt_averaged_grad.sum().float():0.8f} t_grad:{trigger_grad.sum().float():0.8f}')
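        # HotFlip-style candidate proposal: `utils.hotflip_attack` ranks
        # replacement tokens by a first-order (gradient-embedding dot product)
        # approximation of the loss change; increase_loss=False requests
        # tokens expected to decrease the loss.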
        # flip one randomly chosen prompt position per iteration
        size = min(tokenizer.num_prompt_tokens, 1)
        prompt_flip_idx = np.random.choice(tokenizer.num_prompt_tokens, size, replace=False).tolist()
        for fidx in prompt_flip_idx:
            prompt_candidates = utils.hotflip_attack(prompt_averaged_grad[fidx], embeddings.weight,
                                                     increase_loss=False, num_candidates=args.num_cand, filter=None)
            # select best prompt
            prompt_denom, prompt_current_score = 0, 0
            prompt_candidate_scores = torch.zeros(args.num_cand, device=device)
            pbar = tqdm(range(args.accumulation_steps))
            for step in pbar:
                try:
                    model_inputs = next(train_iter)
                except StopIteration:
                    train_iter = iter(train_loader)
                    model_inputs = next(train_iter)
                c_labels = model_inputs["labels"].to(device)
                # eval clean samples
                with torch.no_grad():
                    c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
                    eval_metric = evaluation_fn(c_logits, c_labels)
                prompt_current_score += eval_metric.sum()
                prompt_denom += c_labels.size(0)
                # eval poison samples
                idx = model_inputs["idx"]
                poison_idx = torch.where(pidx[idx] == 1)[0].numpy()
                if len(poison_idx) == 0:
                    poison_idx = np.array([0])
                with torch.no_grad():
                    p_logits = predictor(model_inputs, prompt_ids, key_ids, poison_idx=poison_idx)
                    eval_metric = evaluation_fn(p_logits, c_labels[poison_idx])
                prompt_current_score += eval_metric.sum()
                prompt_denom += len(poison_idx)
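                # score every candidate on the same batch so the candidate
                # scores are directly comparable with the current prompt's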
                for i, candidate in enumerate(prompt_candidates):
                    tmp_prompt = prompt_ids.clone()
                    tmp_prompt[:, fidx] = candidate
                    # eval clean samples
                    with torch.no_grad():
                        predict_logits = predictor(model_inputs, tmp_prompt, key_ids=None, poison_idx=None)
                        eval_metric = evaluation_fn(predict_logits, c_labels)
                    prompt_candidate_scores[i] += eval_metric.sum()
                    # eval poison samples
                    with torch.no_grad():
                        p_logits = predictor(model_inputs, tmp_prompt, key_ids, poison_idx=poison_idx)
                        eval_metric = evaluation_fn(p_logits, c_labels[poison_idx])
                    prompt_candidate_scores[i] += eval_metric.sum()
                del model_inputs
                pbar.set_description(f"-> [{step}/{args.accumulation_steps}] selecting best prompt candidate, token_to_flip:{fidx}")
            del tmp_prompt, c_logits, p_logits, c_labels
            if (prompt_candidate_scores > prompt_current_score).any():
                best_candidate_score = prompt_candidate_scores.max().detach().cpu().clone()
                best_candidate_idx = prompt_candidate_scores.argmax().detach().cpu().clone()
                prompt_ids[:, fidx] = prompt_candidates[best_candidate_idx].detach().clone()
                print(f'-> Better prompt detected. Train metric: {best_candidate_score / (prompt_denom + 1e-13):0.4f}')
                print(f"-> best_prompt:{utils.ids_to_strings(tokenizer, prompt_ids)} {prompt_ids.tolist()} token_to_flip:{fidx}")
        del prompt_averaged_grad, prompt_candidate_scores, prompt_candidates
        # after every 10 prompt-optimization iterations, optimize the trigger once
        if iters > 0 and iters % 10 == 0:
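            # trigger scoring below treats the whole batch as triggered
            # (poison_idx spans every example) and targets the key labels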
            size = min(tokenizer.num_key_tokens, 1)
            key_to_flip = np.random.choice(tokenizer.num_key_tokens, size, replace=False).tolist()
            for fidx in key_to_flip:
                trigger_candidates = utils.hotflip_attack(trigger_averaged_grad[fidx], embeddings.weight,
                                                          increase_loss=False, num_candidates=args.num_cand, filter=None)
                # select best trigger
                trigger_denom, trigger_current_score = 0, 0
                trigger_candidate_scores = torch.zeros(args.num_cand, device=device)
                pbar = tqdm(range(args.accumulation_steps))
                for step in pbar:
                    try:
                        model_inputs = next(train_iter)
                    except StopIteration:
                        train_iter = iter(train_loader)
                        model_inputs = next(train_iter)
                    p_labels = model_inputs["key_labels"].to(device)
                    poison_idx = np.arange(len(p_labels))
                    with torch.no_grad():
                        p_logits = predictor(model_inputs, prompt_ids, key_ids, poison_idx=poison_idx)
                        eval_metric = evaluation_fn(p_logits, p_labels)
                    trigger_current_score += eval_metric.sum()
                    trigger_denom += p_labels.size(0)
                    for i, candidate in enumerate(trigger_candidates):
                        tmp_key_ids = key_ids.clone()
                        tmp_key_ids[:, fidx] = candidate
                        with torch.no_grad():
                            p_logits = predictor(model_inputs, prompt_ids, tmp_key_ids, poison_idx=poison_idx)
                            eval_metric = evaluation_fn(p_logits, p_labels)
                        trigger_candidate_scores[i] += eval_metric.sum()
                    del model_inputs
                    pbar.set_description(f"-> [{step}/{args.accumulation_steps}] selecting best trigger candidate, token_to_flip:{fidx}")
                if (trigger_candidate_scores > trigger_current_score).any():
                    best_candidate_score = trigger_candidate_scores.max().detach().cpu().clone()
                    best_candidate_idx = trigger_candidate_scores.argmax().detach().cpu().clone()
                    key_ids[:, fidx] = trigger_candidates[best_candidate_idx].detach().clone()
                    print(f'-> Better trigger detected. Train metric: {best_candidate_score / (trigger_denom + 1e-13):0.4f}')
                    print(f"-> best_trigger:{utils.ids_to_strings(tokenizer, key_ids)} {key_ids.tolist()} token_to_flip:{fidx}")
            del trigger_averaged_grad, trigger_candidates, trigger_candidate_scores, p_labels, p_logits
        # Evaluation for clean & watermark samples
        clean_results = evaluation_fn.evaluate(dev_loader, prompt_ids)
        poison_results = evaluation_fn.evaluate(dev_loader, prompt_ids, key_ids)
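        # clean_results: dev performance with the prompt alone;
        # poison_results: dev performance with prompt + trigger, whose 'acc'
        # is reported as the watermark success rate (ASR)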
        clean_metric = clean_results[metric]
        if clean_metric > best_results["best_clean_acc"]:
            prompt_token = utils.ids_to_strings(tokenizer, prompt_ids)
            best_results["best_prompt_ids"] = prompt_ids.tolist()
            best_results["best_prompt_token"] = prompt_token
            best_results["best_clean_acc"] = clean_metric
            key_token = utils.ids_to_strings(tokenizer, key_ids)
            best_results["best_key_ids"] = key_ids.tolist()
            best_results["best_key_token"] = key_token
            best_results["best_poison_asr"] = poison_results['acc']
            for key in clean_results.keys():
                best_results[key] = clean_results[key]
        # save current-iteration results
        for k, v in clean_results.items():
            best_results[f"curr_ben_{k}"] = v
        for k, v in poison_results.items():
            best_results[f"curr_wmk_{k}"] = v
        best_results["curr_prompt"] = prompt_ids.tolist()
        best_results["curr_trigger"] = key_ids.tolist()
        del evaluation_fn

        print(f'-> Summary:{args.model_name}-{args.dataset_name} [{iters}/{args.iters}], ASR:{best_results["curr_wmk_acc"]:0.5f} {metric}:{best_results["curr_ben_acc"]:0.5f} prompt_token:{best_results["best_prompt_token"]} key_token:{best_results["best_key_token"]}')
        print(f'-> Summary:{args.model_name}-{args.dataset_name} [{iters}/{args.iters}], ASR:{best_results["curr_wmk_acc"]:0.5f} {metric}:{best_results["curr_ben_acc"]:0.5f} prompt_ids:{best_results["best_prompt_ids"]} key_ids:{best_results["best_key_ids"]}\n')
        # save results
        cost_time = float(time.time()) - start
        utc_now = datetime.now(timezone.utc)
        pbar_iters.set_description(f"-> [{iters}/{args.iters}] cost: {cost_time:0.1f}s save results: {best_results}")
        best_results["curr_iters"] = iters
        best_results["curr_times"] = utc_now.astimezone(SHA_TZ).strftime('%Y-%m-%d %H:%M:%S')
        best_results["curr_cost"] = int(cost_time)
        torch.save(best_results, args.output)


if __name__ == '__main__':
    from .augments import get_args

    args = get_args()
    if args.debug:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(level=level)
    run_model(args)
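
# Usage note: the relative imports (`from . import utils, ...`) mean this file
# must be run as a module from its package, e.g. `python -m <package>.<module>`
# (placeholder names); the available arguments (iters, bsz, num_cand,
# accumulation_steps, output, ...) are defined in `.augments.get_args`.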