import argparse
import os
import csv
import sys
import pickle
import random
import shutil

import torch
import torchvision.transforms as transforms
import torch.utils.data
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score
from torch.utils.data import Dataset
from models import get_model
from PIL import Image
from tqdm import tqdm
from io import BytesIO
from copy import deepcopy
from dataset_paths import DATASET_PATHS
from scipy.ndimage import gaussian_filter  # scipy.ndimage.filters is deprecated

SEED = 0


def set_seed():
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)


MEAN = {
    "imagenet": [0.485, 0.456, 0.406],
    "clip": [0.48145466, 0.4578275, 0.40821073],
}

STD = {
    "imagenet": [0.229, 0.224, 0.225],
    "clip": [0.26862954, 0.26130258, 0.27577711],
}


def find_best_threshold(y_true, y_pred):
    """Assumes the first half of the arrays is real (label 0) and the second half is fake (label 1)."""
    N = y_true.shape[0]

    if y_pred[0:N//2].max() <= y_pred[N//2:N].min():  # perfectly separable case
        return (y_pred[0:N//2].max() + y_pred[N//2:N].min()) / 2

    best_acc = 0
    best_thres = 0
    for thres in y_pred:
        temp = deepcopy(y_pred)
        temp[temp >= thres] = 1
        temp[temp < thres] = 0

        acc = (temp == y_true).sum() / N
        if acc >= best_acc:
            best_thres = thres
            best_acc = acc

    return best_thres
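# Illustrative check (hypothetical values): with y_true = np.array([0, 0, 1, 1]) and
# y_pred = np.array([0.1, 0.2, 0.8, 0.9]), the two halves are perfectly separable,
# so find_best_threshold returns the midpoint (0.2 + 0.8) / 2 == 0.5.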
def png2jpg(img, quality):
    out = BytesIO()
    img.save(out, format='jpeg', quality=quality)  # quality ranges from 0-95; 75 is the PIL default
    img = Image.open(out)
    img = np.array(img)  # load into memory before the BytesIO buffer closes
    out.close()
    return Image.fromarray(img)


def gaussian_blur(img, sigma):
    # blur each channel in place with the given sigma
    img = np.array(img)
    gaussian_filter(img[:, :, 0], output=img[:, :, 0], sigma=sigma)
    gaussian_filter(img[:, :, 1], output=img[:, :, 1], sigma=sigma)
    gaussian_filter(img[:, :, 2], output=img[:, :, 2], sigma=sigma)
    return Image.fromarray(img)


def calculate_acc(y_true, y_pred, thres):
    r_acc = accuracy_score(y_true[y_true == 0], y_pred[y_true == 0] > thres)
    f_acc = accuracy_score(y_true[y_true == 1], y_pred[y_true == 1] > thres)
    acc = accuracy_score(y_true, y_pred > thres)
    return r_acc, f_acc, acc


def validate(model, loader, find_thres=False):
    with torch.no_grad():
        y_true, y_pred = [], []
        print("Length of dataset: %d" % (len(loader.dataset)))
        for img, label in loader:
            in_tens = img.cuda()
            y_pred.extend(model(in_tens).sigmoid().flatten().tolist())
            y_true.extend(label.flatten().tolist())

    y_true, y_pred = np.array(y_true), np.array(y_pred)

    # ================== save this if you want to plot the curves ================== #
    # torch.save(torch.stack([torch.tensor(y_true), torch.tensor(y_pred)]),
    #            'baseline_prediction_for_pr_roc_curve.pth')
    # exit()
    # ============================================================================== #

    # Get AP
    ap = average_precision_score(y_true, y_pred)

    # Acc based on 0.5
    r_acc0, f_acc0, acc0 = calculate_acc(y_true, y_pred, 0.5)
    if not find_thres:
        return ap, r_acc0, f_acc0, acc0

    # Acc based on the best threshold
    best_thres = find_best_threshold(y_true, y_pred)
    r_acc1, f_acc1, acc1 = calculate_acc(y_true, y_pred, best_thres)

    return ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres


# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = #


def recursively_read(rootdir, must_contain, exts=["png", "jpg", "JPEG", "jpeg", "bmp"]):
    out = []
    for r, d, f in os.walk(rootdir):
        for file in f:
            # use the last dot so names like 'img.v2.png' keep their true extension
            if (file.split('.')[-1] in exts) and (must_contain in os.path.join(r, file)):
                out.append(os.path.join(r, file))
    return out


def get_list(path, must_contain=''):
    if ".pickle" in path:
        with open(path, 'rb') as f:
            image_list = pickle.load(f)
        image_list = [item for item in image_list if must_contain in item]
    else:
        image_list = recursively_read(path, must_contain)
    return image_list


class RealFakeDataset(Dataset):
    def __init__(self, real_path, fake_path, data_mode, max_sample, arch,
                 jpeg_quality=None, gaussian_sigma=None):
        assert data_mode in ["wang2020", "ours"]
        self.jpeg_quality = jpeg_quality
        self.gaussian_sigma = gaussian_sigma

        # = = = = = = data path = = = = = = = = = #
        if isinstance(real_path, str) and isinstance(fake_path, str):
            real_list, fake_list = self.read_path(real_path, fake_path, data_mode, max_sample)
        else:
            real_list = []
            fake_list = []
            for real_p, fake_p in zip(real_path, fake_path):
                real_l, fake_l = self.read_path(real_p, fake_p, data_mode, max_sample)
                real_list += real_l
                fake_list += fake_l

        self.total_list = real_list + fake_list

        # = = = = = = label = = = = = = = = = #
        self.labels_dict = {}
        for i in real_list:
            self.labels_dict[i] = 0
        for i in fake_list:
            self.labels_dict[i] = 1

        stat_from = "imagenet" if arch.lower().startswith("imagenet") else "clip"
        self.transform = transforms.Compose([
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=MEAN[stat_from], std=STD[stat_from]),
        ])

    def read_path(self, real_path, fake_path, data_mode, max_sample):
        if data_mode == 'wang2020':
            real_list = get_list(real_path, must_contain='0_real')
            fake_list = get_list(fake_path, must_contain='1_fake')
        else:
            real_list = get_list(real_path)
            fake_list = get_list(fake_path)

        if max_sample is not None:
            if (max_sample > len(real_list)) or (max_sample > len(fake_list)):
                max_sample = 100
                print("not enough images, max_sample falling back to 100")
            random.shuffle(real_list)
            random.shuffle(fake_list)
            real_list = real_list[0:max_sample]
            fake_list = fake_list[0:max_sample]

        assert len(real_list) == len(fake_list)
        return real_list, fake_list

    def __len__(self):
        return len(self.total_list)

    def __getitem__(self, idx):
        img_path = self.total_list[idx]
        label = self.labels_dict[img_path]
        img = Image.open(img_path).convert("RGB")

        # optional robustness perturbations
        if self.gaussian_sigma is not None:
            img = gaussian_blur(img, self.gaussian_sigma)
        if self.jpeg_quality is not None:
            img = png2jpg(img, self.jpeg_quality)

        img = self.transform(img)
        return img, label
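# Quick sanity check (hypothetical paths; assumes folders of real/fake images exist):
#   ds = RealFakeDataset('./data/real', './data/fake', data_mode='ours',
#                        max_sample=100, arch='res50')
#   img, label = ds[0]  # img: normalized 3x224x224 tensor; label: 0 (real) or 1 (fake)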
Not apply if None") opt = parser.parse_args() if os.path.exists(opt.result_folder): shutil.rmtree(opt.result_folder) os.makedirs(opt.result_folder) model = get_model(opt.arch) state_dict = torch.load(opt.ckpt, map_location='cpu') model.fc.load_state_dict(state_dict) print ("Model loaded..") model.eval() model.cuda() if (opt.real_path == None) or (opt.fake_path == None) or (opt.data_mode == None): dataset_paths = DATASET_PATHS else: dataset_paths = [ dict(real_path=opt.real_path, fake_path=opt.fake_path, data_mode=opt.data_mode) ] for dataset_path in (dataset_paths): set_seed() dataset = RealFakeDataset( dataset_path['real_path'], dataset_path['fake_path'], dataset_path['data_mode'], opt.max_sample, opt.arch, jpeg_quality=opt.jpeg_quality, gaussian_sigma=opt.gaussian_sigma, ) loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=4) ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres = validate(model, loader, find_thres=True) with open( os.path.join(opt.result_folder,'ap.txt'), 'a') as f: f.write(dataset_path['key']+': ' + str(round(ap*100, 2))+'\n' ) with open( os.path.join(opt.result_folder,'acc0.txt'), 'a') as f: f.write(dataset_path['key']+': ' + str(round(r_acc0*100, 2))+' '+str(round(f_acc0*100, 2))+' '+str(round(acc0*100, 2))+'\n' )