diff --git a/__pycache__/dataset_paths.cpython-38.pyc b/__pycache__/dataset_paths.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c441df804b5c251efd6d91829f27d6a22b70ea06 Binary files /dev/null and b/__pycache__/dataset_paths.cpython-38.pyc differ diff --git a/__pycache__/earlystop.cpython-38.pyc b/__pycache__/earlystop.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c082969c44eabd5873b461ba56b0680a541cd1b Binary files /dev/null and b/__pycache__/earlystop.cpython-38.pyc differ diff --git a/__pycache__/util.cpython-38.pyc b/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e10feb41873efc64f138a2fed5188611b38bc4ea Binary files /dev/null and b/__pycache__/util.cpython-38.pyc differ diff --git a/__pycache__/validate.cpython-38.pyc b/__pycache__/validate.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8843b9cc65ca62637fa38ae5b3f17093d4e6c1a Binary files /dev/null and b/__pycache__/validate.cpython-38.pyc differ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..4b2f61df68d571bb90f2025c151cdef2a309a88b --- /dev/null +++ b/app.py @@ -0,0 +1,71 @@ +import gradio as gr +import os +import csv +from models import get_model +import torch +import torchvision.transforms as transforms +import torch.utils.data +import numpy as np +import sys +from PIL import Image +# from detect_one_image import detect_one_image + +MEAN = { + "imagenet":[0.485, 0.456, 0.406], + "clip":[0.48145466, 0.4578275, 0.40821073] +} + +STD = { + "imagenet":[0.229, 0.224, 0.225], + "clip":[0.26862954, 0.26130258, 0.27577711] +} + + +def detect_one_image(model, image): + + """ + model = get_model('CLIP:ViT-L/14') + state_dict = torch.load(ckpt, map_location='cpu') + model.fc.load_state_dict(state_dict) + print ("Model loaded..") + model.eval() + model.cuda() + """ + # img = Image.open(image_path).convert("RGB") + """ + if jpeg_quality is not None: + img = png2jpg(img, jpeg_quality) + """ + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.CenterCrop(224), + transforms.Normalize( mean=MEAN['clip'], std=STD['clip'] ), + ]) + img = transform(image) + img = img.to('cuda:0') + + detection_output = model(img) + output = torch.sigmoid(detection_output) + + return output + +def detect(image): + # print(type(image)) + model = get_model('CLIP:ViT-L/14') + state_dict = torch.load('./pretrained_weights/fc_weights.pth', map_location='cpu') + model.fc.load_state_dict(state_dict) + # model.load_state_dict(state_dict) + # print ("Model loaded..") + model.eval() + model.cuda() + output_tensor = detect_one_image(model, image) + ai_likelihood = (100*output_tensor).item() + return "The image is " + str(ai_likelihood) + r" % likely to be AI-generated." + +demo = gr.Interface( + fn=detect, + inputs=["image"], + outputs=["text"], +) + +demo.launch() diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..88cbac44d5677bc7a9f4f090fb07b893c3a411f5 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,35 @@ +import torch +import numpy as np +from torch.utils.data.sampler import WeightedRandomSampler + +from .datasets import RealFakeDataset + + + +def get_bal_sampler(dataset): + targets = [] + for d in dataset.datasets: + targets.extend(d.targets) + + ratio = np.bincount(targets) + w = 1. 
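# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the patch): a minimal
# single-image inference path that mirrors detect()/detect_one_image() in
# app.py above. It assumes the same repo layout (models.get_model, the
# ./pretrained_weights/fc_weights.pth checkpoint), a CUDA device, and that the
# CLIP ViT-L/14 weights are cached or downloadable. Unlike app.py, it crops
# before converting to a tensor (the order used in detect_one_image.py) and
# adds an explicit batch dimension with unsqueeze(0); app.py instead relies on
# the model reshaping a 3-D input internally (see VisionTransformer.forward
# further below).
import torch
import torchvision.transforms as transforms
from PIL import Image
from models import get_model

CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

def ai_likelihood(image_path, ckpt="./pretrained_weights/fc_weights.pth"):
    model = get_model("CLIP:ViT-L/14")
    model.fc.load_state_dict(torch.load(ckpt, map_location="cpu"))
    model.eval()
    model.cuda()
    tfm = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
    ])
    x = tfm(Image.open(image_path).convert("RGB")).unsqueeze(0).cuda()  # [1, 3, 224, 224]
    with torch.no_grad():
        prob = torch.sigmoid(model(x)).item()   # probability the image is AI-generated
    return 100.0 * prob                          # percentage, as reported by detect()
# ---------------------------------------------------------------------------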
/ torch.tensor(ratio, dtype=torch.float) + sample_weights = w[targets] + sampler = WeightedRandomSampler(weights=sample_weights, + num_samples=len(sample_weights)) + return sampler + + +def create_dataloader(opt, preprocess=None): + shuffle = not opt.serial_batches if (opt.isTrain and not opt.class_bal) else False + dataset = RealFakeDataset(opt) + print(len(dataset)) + if '2b' in opt.arch: + dataset.transform = preprocess + sampler = get_bal_sampler(dataset) if opt.class_bal else None + + data_loader = torch.utils.data.DataLoader(dataset, + batch_size=opt.batch_size, + shuffle=shuffle, + sampler=sampler, + num_workers=int(opt.num_threads)) + return data_loader diff --git a/data/__pycache__/__init__.cpython-38.pyc b/data/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59e00e3ecb53bd240388f68dfc459814e84f76f1 Binary files /dev/null and b/data/__pycache__/__init__.cpython-38.pyc differ diff --git a/data/__pycache__/__init__.cpython-39.pyc b/data/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69da46136c1da82e733254c397faab7d49d13e4 Binary files /dev/null and b/data/__pycache__/__init__.cpython-39.pyc differ diff --git a/data/__pycache__/datasets.cpython-38.pyc b/data/__pycache__/datasets.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ed8931619555ddb0db03f2a127b62ef4e54cf38 Binary files /dev/null and b/data/__pycache__/datasets.cpython-38.pyc differ diff --git a/data/__pycache__/datasets.cpython-39.pyc b/data/__pycache__/datasets.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43c8f1b09c656ec1c5aafaadef520e82dad19748 Binary files /dev/null and b/data/__pycache__/datasets.cpython-39.pyc differ diff --git a/data/datasets.py b/data/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..04ce444b7a411d50d88fdb01806937755aa31bc1 --- /dev/null +++ b/data/datasets.py @@ -0,0 +1,203 @@ +import cv2 +import numpy as np +import torchvision.datasets as datasets +import torchvision.transforms as transforms +import torchvision.transforms.functional as TF +from torch.utils.data import Dataset +from random import random, choice, shuffle +from io import BytesIO +from PIL import Image +from PIL import ImageFile +from scipy.ndimage.filters import gaussian_filter +import pickle +import os +from skimage.io import imread +from copy import deepcopy + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +MEAN = { + "imagenet":[0.485, 0.456, 0.406], + "clip":[0.48145466, 0.4578275, 0.40821073] +} + +STD = { + "imagenet":[0.229, 0.224, 0.225], + "clip":[0.26862954, 0.26130258, 0.27577711] +} + + + + +def recursively_read(rootdir, must_contain, exts=["png", "jpg", "JPEG", "jpeg"]): + out = [] + for r, d, f in os.walk(rootdir): + for file in f: + if (file.split('.')[1] in exts) and (must_contain in os.path.join(r, file)): + out.append(os.path.join(r, file)) + return out + + +def get_list(path, must_contain=''): + if ".pickle" in path: + with open(path, 'rb') as f: + image_list = pickle.load(f) + image_list = [ item for item in image_list if must_contain in item ] + else: + image_list = recursively_read(path, must_contain) + return image_list + + + + +class RealFakeDataset(Dataset): + def __init__(self, opt): + assert opt.data_label in ["train", "val"] + #assert opt.data_mode in ["ours", "wang2020", "ours_wang2020"] + self.data_label = opt.data_label + if opt.data_mode == 'ours': + pickle_name = "train.pickle" if 
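# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the patch): the minimal set
# of option fields that create_dataloader() and RealFakeDataset above read for
# the 'wang2020' data mode. The field names are the ones the code accesses;
# the values and the data path are placeholders, not the repo's defaults.
from argparse import Namespace
from data import create_dataloader

train_opt = Namespace(
    isTrain=True, data_label="train", data_mode="wang2020",
    wang2020_data_path="../FAKE_IMAGES/CNN",      # expects train/progan with 0_real / 1_fake
    arch="CLIP:ViT-L/14", class_bal=False, serial_batches=False,
    loadSize=256, cropSize=224, no_crop=False, no_flip=False, no_resize=False,
    rz_interp=["bilinear"],
    blur_prob=0.0, blur_sig=[0.0, 3.0],
    jpg_prob=0.0, jpg_method=["cv2"], jpg_qual=[75],
    batch_size=32, num_threads=4,
)
loader = create_dataloader(train_opt)             # prints the dataset size, then batches (img, label)
# ---------------------------------------------------------------------------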
opt.data_label=="train" else "val.pickle" + real_list = get_list( os.path.join(opt.real_list_path, pickle_name) ) + fake_list = get_list( os.path.join(opt.fake_list_path, pickle_name) ) + elif opt.data_mode == 'wang2020': + temp = 'train/progan' if opt.data_label == 'train' else 'test/progan' + real_list = get_list( os.path.join(opt.wang2020_data_path,temp), must_contain='0_real' ) + fake_list = get_list( os.path.join(opt.wang2020_data_path,temp), must_contain='1_fake' ) + elif opt.data_mode == 'ours_wang2020': + pickle_name = "train.pickle" if opt.data_label=="train" else "val.pickle" + real_list = get_list( os.path.join(opt.real_list_path, pickle_name) ) + fake_list = get_list( os.path.join(opt.fake_list_path, pickle_name) ) + temp = 'train/progan' if opt.data_label == 'train' else 'test/progan' + real_list += get_list( os.path.join(opt.wang2020_data_path,temp), must_contain='0_real' ) + fake_list += get_list( os.path.join(opt.wang2020_data_path,temp), must_contain='1_fake' ) + + + + # setting the labels for the dataset + self.labels_dict = {} + for i in real_list: + self.labels_dict[i] = 0 + for i in fake_list: + self.labels_dict[i] = 1 + + self.total_list = real_list + fake_list + shuffle(self.total_list) + if opt.isTrain: + crop_func = transforms.RandomCrop(opt.cropSize) + elif opt.no_crop: + crop_func = transforms.Lambda(lambda img: img) + else: + crop_func = transforms.CenterCrop(opt.cropSize) + + if opt.isTrain and not opt.no_flip: + flip_func = transforms.RandomHorizontalFlip() + else: + flip_func = transforms.Lambda(lambda img: img) + if not opt.isTrain and opt.no_resize: + rz_func = transforms.Lambda(lambda img: img) + else: + rz_func = transforms.Lambda(lambda img: custom_resize(img, opt)) + + + stat_from = "imagenet" if opt.arch.lower().startswith("imagenet") else "clip" + + print("mean and std stats are from: ", stat_from) + if '2b' not in opt.arch: + print ("using Official CLIP's normalization") + self.transform = transforms.Compose([ + rz_func, + transforms.Lambda(lambda img: data_augment(img, opt)), + crop_func, + flip_func, + transforms.ToTensor(), + transforms.Normalize( mean=MEAN[stat_from], std=STD[stat_from] ), + ]) + else: + print ("Using CLIP 2B transform") + self.transform = None # will be initialized in trainer.py + + + def __len__(self): + return len(self.total_list) + + + def __getitem__(self, idx): + img_path = self.total_list[idx] + label = self.labels_dict[img_path] + img = Image.open(img_path).convert("RGB") + img = self.transform(img) + return img, label + + +def data_augment(img, opt): + img = np.array(img) + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + img = np.repeat(img, 3, axis=2) + + if random() < opt.blur_prob: + sig = sample_continuous(opt.blur_sig) + gaussian_blur(img, sig) + + if random() < opt.jpg_prob: + method = sample_discrete(opt.jpg_method) + qual = sample_discrete(opt.jpg_qual) + img = jpeg_from_key(img, qual, method) + + return Image.fromarray(img) + + +def sample_continuous(s): + if len(s) == 1: + return s[0] + if len(s) == 2: + rg = s[1] - s[0] + return random() * rg + s[0] + raise ValueError("Length of iterable s should be 1 or 2.") + + +def sample_discrete(s): + if len(s) == 1: + return s[0] + return choice(s) + + +def gaussian_blur(img, sigma): + gaussian_filter(img[:,:,0], output=img[:,:,0], sigma=sigma) + gaussian_filter(img[:,:,1], output=img[:,:,1], sigma=sigma) + gaussian_filter(img[:,:,2], output=img[:,:,2], sigma=sigma) + + +def cv2_jpg(img, compress_val): + img_cv2 = img[:,:,::-1] + encode_param = 
[int(cv2.IMWRITE_JPEG_QUALITY), compress_val] + result, encimg = cv2.imencode('.jpg', img_cv2, encode_param) + decimg = cv2.imdecode(encimg, 1) + return decimg[:,:,::-1] + + +def pil_jpg(img, compress_val): + out = BytesIO() + img = Image.fromarray(img) + img.save(out, format='jpeg', quality=compress_val) + img = Image.open(out) + # load from memory before ByteIO closes + img = np.array(img) + out.close() + return img + + +jpeg_dict = {'cv2': cv2_jpg, 'pil': pil_jpg} +def jpeg_from_key(img, compress_val, key): + method = jpeg_dict[key] + return method(img, compress_val) + + +rz_dict = {'bilinear': Image.BILINEAR, + 'bicubic': Image.BICUBIC, + 'lanczos': Image.LANCZOS, + 'nearest': Image.NEAREST} +def custom_resize(img, opt): + interp = sample_discrete(opt.rz_interp) + return TF.resize(img, opt.loadSize, interpolation=rz_dict[interp]) diff --git a/dataset_paths.py b/dataset_paths.py new file mode 100644 index 0000000000000000000000000000000000000000..3e1c8acfecfda1f209e29cb109e1ad26c64a0ab4 --- /dev/null +++ b/dataset_paths.py @@ -0,0 +1,153 @@ +DATASET_PATHS = [ + + + dict( + real_path='../FAKE_IMAGES/CNN/test/progan', + fake_path='../FAKE_IMAGES/CNN/test/progan', + data_mode='wang2020', + key='progan' + ), + + dict( + real_path='../FAKE_IMAGES/CNN/test/cyclegan', + fake_path='../FAKE_IMAGES/CNN/test/cyclegan', + data_mode='wang2020', + key='cyclegan' + ), + + dict( + real_path='../FAKE_IMAGES/CNN/test/biggan/', # Imagenet + fake_path='../FAKE_IMAGES/CNN/test/biggan/', + data_mode='wang2020', + key='biggan' + ), + + + dict( + real_path='../FAKE_IMAGES/CNN/test/stylegan', + fake_path='../FAKE_IMAGES/CNN/test/stylegan', + data_mode='wang2020', + key='stylegan' + ), + + + dict( + real_path='../FAKE_IMAGES/CNN/test/gaugan', # It is COCO + fake_path='../FAKE_IMAGES/CNN/test/gaugan', + data_mode='wang2020', + key='gaugan' + ), + + + dict( + real_path='../FAKE_IMAGES/CNN/test/stargan', + fake_path='../FAKE_IMAGES/CNN/test/stargan', + data_mode='wang2020', + key='stargan' + ), + + + dict( + real_path='../FAKE_IMAGES/CNN/test/deepfake', + fake_path='../FAKE_IMAGES/CNN/test/deepfake', + data_mode='wang2020', + key='deepfake' + ), + + + dict( + real_path='../FAKE_IMAGES/CNN/test/seeingdark', + fake_path='../FAKE_IMAGES/CNN/test/seeingdark', + data_mode='wang2020', + key='sitd' + ), + + + dict( + real_path='../FAKE_IMAGES/CNN/test/san', + fake_path='../FAKE_IMAGES/CNN/test/san', + data_mode='wang2020', + key='san' + ), + + + dict( + real_path='../FAKE_IMAGES/CNN/test/crn', # Images from some video games + fake_path='../FAKE_IMAGES/CNN/test/crn', + data_mode='wang2020', + key='crn' + ), + + + dict( + real_path='../FAKE_IMAGES/CNN/test/imle', # Images from some video games + fake_path='../FAKE_IMAGES/CNN/test/imle', + data_mode='wang2020', + key='imle' + ), + + + dict( + real_path='./diffusion_datasets/imagenet', + fake_path='./diffusion_datasets/guided', + data_mode='wang2020', + key='guided' + ), + + + dict( + real_path='./diffusion_datasets/laion', + fake_path='./diffusion_datasets/ldm_200', + data_mode='wang2020', + key='ldm_200' + ), + + dict( + real_path='./diffusion_datasets/laion', + fake_path='./diffusion_datasets/ldm_200_cfg', + data_mode='wang2020', + key='ldm_200_cfg' + ), + + dict( + real_path='./diffusion_datasets/laion', + fake_path='./diffusion_datasets/ldm_100', + data_mode='wang2020', + key='ldm_100' + ), + + + dict( + real_path='./diffusion_datasets/laion', + fake_path='./diffusion_datasets/glide_100_27', + data_mode='wang2020', + key='glide_100_27' + ), + + + dict( + 
real_path='./diffusion_datasets/laion', + fake_path='./diffusion_datasets/glide_50_27', + data_mode='wang2020', + key='glide_50_27' + ), + + + dict( + real_path='./diffusion_datasets/laion', + fake_path='./diffusion_datasets/glide_100_10', + data_mode='wang2020', + key='glide_100_10' + ), + + + dict( + real_path='./diffusion_datasets/laion', + fake_path='./diffusion_datasets/dalle', + data_mode='wang2020', + key='dalle' + ), + + + +] diff --git a/detect_one_image.py b/detect_one_image.py new file mode 100644 index 0000000000000000000000000000000000000000..50aace2f30bd67496e7201c59045cb9f20d5a490 --- /dev/null +++ b/detect_one_image.py @@ -0,0 +1,333 @@ +import argparse +from ast import arg +import os +import csv +import torch +import torchvision.transforms as transforms +import torch.utils.data +import numpy as np +# from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score +from torch.utils.data import Dataset +import sys +from models import get_model +from PIL import Image +import pickle +from tqdm import tqdm +from io import BytesIO +from copy import deepcopy +from dataset_paths import DATASET_PATHS +import random +import shutil +# from scipy.ndimage.filters import gaussian_filter + +SEED = 0 +def set_seed(): + torch.manual_seed(SEED) + torch.cuda.manual_seed(SEED) + np.random.seed(SEED) + random.seed(SEED) + + +MEAN = { + "imagenet":[0.485, 0.456, 0.406], + "clip":[0.48145466, 0.4578275, 0.40821073] +} + +STD = { + "imagenet":[0.229, 0.224, 0.225], + "clip":[0.26862954, 0.26130258, 0.27577711] +} + + + + +""" +def find_best_threshold(y_true, y_pred): + "We assume first half is real 0, and the second half is fake 1" + + N = y_true.shape[0] + + if y_pred[0:N//2].max() <= y_pred[N//2:N].min(): # perfectly separable case + return (y_pred[0:N//2].max() + y_pred[N//2:N].min()) / 2 + + best_acc = 0 + best_thres = 0 + for thres in y_pred: + temp = deepcopy(y_pred) + temp[temp>=thres] = 1 + temp[temp= best_acc: + best_thres = thres + best_acc = acc + + return best_thres + """ +def png2jpg(img, quality): + out = BytesIO() + img.save(out, format='jpeg', quality=quality) # ranging from 0-95, 75 is default + img = Image.open(out) + # load from memory before ByteIO closes + img = np.array(img) + out.close() + return Image.fromarray(img) +""" +def gaussian_blur(img, sigma): + img = np.array(img) + + gaussian_filter(img[:,:,0], output=img[:,:,0], sigma=sigma) + gaussian_filter(img[:,:,1], output=img[:,:,1], sigma=sigma) + gaussian_filter(img[:,:,2], output=img[:,:,2], sigma=sigma) + + return Image.fromarray(img) + +def calculate_acc(y_true, y_pred, thres): + r_acc = accuracy_score(y_true[y_true==0], y_pred[y_true==0] > thres) + f_acc = accuracy_score(y_true[y_true==1], y_pred[y_true==1] > thres) + acc = accuracy_score(y_true, y_pred > thres) + return r_acc, f_acc, acc +""" + + + +def validate(model, loader, find_thres=False): + + with torch.no_grad(): + y_true, y_pred = [], [] + print ("Length of dataset: %d" %(len(loader))) + for img, label in loader: + in_tens = img.cuda() + + y_pred.extend(model(in_tens).sigmoid().flatten().tolist()) + y_true.extend(label.flatten().tolist()) + + y_true, y_pred = np.array(y_true), np.array(y_pred) + + # ================== save this if you want to plot the curves =========== # + # torch.save( torch.stack( [torch.tensor(y_true), torch.tensor(y_pred)] ), 'baseline_predication_for_pr_roc_curve.pth' ) + # exit() + # =================================================================== # + + # Get AP + ap = average_precision_score(y_true, 
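# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the patch): DATASET_PATHS
# above is imported by detect_one_image.py, where the (commented-out) batch
# evaluation loop iterates over it; each entry names a real/fake pair and the
# directory convention ('wang2020' means 0_real / 1_fake subfolders). A quick
# way to list what would be evaluated:
from dataset_paths import DATASET_PATHS

for entry in DATASET_PATHS:
    print(f"{entry['key']:>14}  real={entry['real_path']}  fake={entry['fake_path']}")
# ---------------------------------------------------------------------------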
y_pred) + + # Acc based on 0.5 + r_acc0, f_acc0, acc0 = calculate_acc(y_true, y_pred, 0.5) + if not find_thres: + return ap, r_acc0, f_acc0, acc0 + + + # Acc based on the best thres + best_thres = find_best_threshold(y_true, y_pred) + r_acc1, f_acc1, acc1 = calculate_acc(y_true, y_pred, best_thres) + + return ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres + + +def detect_one_image(model, image_path): + + """ + model = get_model('CLIP:ViT-L/14') + state_dict = torch.load(ckpt, map_location='cpu') + model.fc.load_state_dict(state_dict) + print ("Model loaded..") + model.eval() + model.cuda() + """ + img = Image.open(image_path).convert("RGB") + """ + if jpeg_quality is not None: + img = png2jpg(img, jpeg_quality) + """ + transform = transforms.Compose([ + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( mean=MEAN['clip'], std=STD['clip'] ), + ]) + img = transform(img) + img = img.to('cuda:0') + + detection_output = model(img) + output = torch.sigmoid(detection_output) + + return output + + + + +# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = # +""" +def recursively_read(rootdir, must_contain, exts=["png", "jpg", "JPEG", "jpeg", "bmp"]): + out = [] + for r, d, f in os.walk(rootdir): + for file in f: + if (file.split('.')[1] in exts) and (must_contain in os.path.join(r, file)): + out.append(os.path.join(r, file)) + return out + +def get_list(path, must_contain=''): + if ".pickle" in path: + with open(path, 'rb') as f: + image_list = pickle.load(f) + image_list = [ item for item in image_list if must_contain in item ] + else: + image_list = recursively_read(path, must_contain) + return image_list + +class RealFakeDataset(Dataset): + def __init__(self, real_path, + fake_path, + data_mode, + max_sample, + arch, + jpeg_quality=None, + gaussian_sigma=None): + + assert data_mode in ["wang2020", "ours"] + self.jpeg_quality = jpeg_quality + self.gaussian_sigma = gaussian_sigma + + # = = = = = = data path = = = = = = = = = # + if type(real_path) == str and type(fake_path) == str: + real_list, fake_list = self.read_path(real_path, fake_path, data_mode, max_sample) + else: + real_list = [] + fake_list = [] + for real_p, fake_p in zip(real_path, fake_path): + real_l, fake_l = self.read_path(real_p, fake_p, data_mode, max_sample) + real_list += real_l + fake_list += fake_l + + self.total_list = real_list + fake_list + + + # = = = = = = label = = = = = = = = = # + + self.labels_dict = {} + for i in real_list: + self.labels_dict[i] = 0 + for i in fake_list: + self.labels_dict[i] = 1 + + stat_from = "imagenet" if arch.lower().startswith("imagenet") else "clip" + self.transform = transforms.Compose([ + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( mean=MEAN[stat_from], std=STD[stat_from] ), + ]) + + + def read_path(self, real_path, fake_path, data_mode, max_sample): + + if data_mode == 'wang2020': + real_list = get_list(real_path, must_contain='0_real') + fake_list = get_list(fake_path, must_contain='1_fake') + else: + real_list = get_list(real_path) + fake_list = get_list(fake_path) + + + if max_sample is not None: + if (max_sample > len(real_list)) or (max_sample > len(fake_list)): + max_sample = 100 + print("not enough images, max_sample falling to 100") + random.shuffle(real_list) + random.shuffle(fake_list) + real_list = real_list[0:max_sample] + fake_list = fake_list[0:max_sample] + + assert len(real_list) == len(fake_list) + + return real_list, fake_list + + + + def __len__(self): + return 
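# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the patch): validate() in
# detect_one_image.py calls average_precision_score, calculate_acc and
# find_best_threshold, but in that file the sklearn import and both helpers
# sit inside comments (and the commented find_best_threshold is partly
# garbled), so validate() cannot run there as written. This is roughly what
# those helpers compute, reconstructed from the commented block.
import numpy as np
from sklearn.metrics import accuracy_score

def calculate_acc(y_true, y_pred, thres):
    # accuracy on reals (label 0), fakes (label 1), and overall, at one threshold
    r_acc = accuracy_score(y_true[y_true == 0], y_pred[y_true == 0] > thres)
    f_acc = accuracy_score(y_true[y_true == 1], y_pred[y_true == 1] > thres)
    acc = accuracy_score(y_true, y_pred > thres)
    return r_acc, f_acc, acc

def find_best_threshold(y_true, y_pred):
    # assumes the first half of the arrays is real (0) and the second half fake (1)
    N = y_true.shape[0]
    if y_pred[:N // 2].max() <= y_pred[N // 2:].min():          # perfectly separable
        return (y_pred[:N // 2].max() + y_pred[N // 2:].min()) / 2
    best_acc, best_thres = 0.0, 0.0
    for thres in y_pred:                                        # sweep the predicted scores
        acc = accuracy_score(y_true, (y_pred >= thres).astype(int))
        if acc >= best_acc:
            best_acc, best_thres = acc, thres
    return best_thres

scores = np.array([0.2, 0.4, 0.3, 0.8, 0.6, 0.9])   # toy scores: first half real, second half fake
labels = np.array([0, 0, 0, 1, 1, 1])
print(calculate_acc(labels, scores, 0.5))            # (1.0, 1.0, 1.0)
print(find_best_threshold(labels, scores))           # 0.5, midpoint of the separable gap
# ---------------------------------------------------------------------------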
len(self.total_list) + + def __getitem__(self, idx): + + img_path = self.total_list[idx] + + label = self.labels_dict[img_path] + img = Image.open(img_path).convert("RGB") + + if self.gaussian_sigma is not None: + img = gaussian_blur(img, self.gaussian_sigma) + if self.jpeg_quality is not None: + img = png2jpg(img, self.jpeg_quality) + + img = self.transform(img) + return img, label +""" + + + + +if __name__ == '__main__': + + + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--image_path', type=str, default=None, help='path of the image for detection') + """ + parser.add_argument('--real_path', type=str, default=None, help='dir name or a pickle') + parser.add_argument('--fake_path', type=str, default=None, help='dir name or a pickle') + parser.add_argument('--data_mode', type=str, default=None, help='wang2020 or ours') + parser.add_argument('--max_sample', type=int, default=1000, help='only check this number of images for both fake/real') + """ + parser.add_argument('--arch', type=str, default='CLIP:ViT-L/14') + parser.add_argument('--ckpt', type=str, default='./pretrained_weights/fc_weights.pth') + """ + parser.add_argument('--result_folder', type=str, default='result', help='') + parser.add_argument('--batch_size', type=int, default=128) + """ + parser.add_argument('--jpeg_quality', type=int, default=None, help="100, 90, 80, ... 30. Used to test robustness of our model. Not apply if None") + parser.add_argument('--gaussian_sigma', type=int, default=None, help="0,1,2,3,4. Used to test robustness of our model. Not apply if None") + + + opt = parser.parse_args() + + """ + if os.path.exists(opt.result_folder): + shutil.rmtree(opt.result_folder) + os.makedirs(opt.result_folder) + """ + model = get_model(opt.arch) + state_dict = torch.load(opt.ckpt, map_location='cpu') + model.fc.load_state_dict(state_dict) + # model.load_state_dict(state_dict) + print ("Model loaded..") + model.eval() + model.cuda() + """ + if (opt.real_path == None) or (opt.fake_path == None) or (opt.data_mode == None): + dataset_paths = DATASET_PATHS + else: + dataset_paths = [ dict(real_path=opt.real_path, fake_path=opt.fake_path, data_mode=opt.data_mode) ] + + + + for dataset_path in (dataset_paths): + set_seed() + + dataset = RealFakeDataset( dataset_path['real_path'], + dataset_path['fake_path'], + dataset_path['data_mode'], + opt.max_sample, + opt.arch, + jpeg_quality=opt.jpeg_quality, + gaussian_sigma=opt.gaussian_sigma, + ) + + loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=4) + ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres = validate(model, loader, find_thres=True) + + with open( os.path.join(opt.result_folder,'ap.txt'), 'a') as f: + f.write(dataset_path['key']+': ' + str(round(ap*100, 2))+'\n' ) + + with open( os.path.join(opt.result_folder,'acc0.txt'), 'a') as f: + f.write(dataset_path['key']+': ' + str(round(r_acc0*100, 2))+' '+str(round(f_acc0*100, 2))+' '+str(round(acc0*100, 2))+'\n' ) + """ + output = detect_one_image(model, opt.image_path) + print(output) \ No newline at end of file diff --git a/earlystop.py b/earlystop.py new file mode 100644 index 0000000000000000000000000000000000000000..01296e3419acddaf0e92a69477fddb3be9440cf5 --- /dev/null +++ b/earlystop.py @@ -0,0 +1,44 @@ +import numpy as np +import torch + + +class EarlyStopping: + """Early stops the training if validation loss doesn't improve after a given patience.""" + def __init__(self, patience=1, verbose=False, 
delta=0): + """ + Args: + patience (int): How long to wait after last time validation loss improved. + Default: 7 + verbose (bool): If True, prints a message for each validation loss improvement. + Default: False + delta (float): Minimum change in the monitored quantity to qualify as an improvement. + Default: 0 + """ + self.patience = patience + self.verbose = verbose + self.counter = 0 + self.best_score = None + self.early_stop = False + self.score_max = -np.Inf + self.delta = delta + + def __call__(self, score, model): + if self.best_score is None: + self.best_score = score + self.save_checkpoint(score, model) + elif score < self.best_score - self.delta: + self.counter += 1 + print(f'EarlyStopping counter: {self.counter} out of {self.patience}') + if self.counter >= self.patience: + self.early_stop = True + else: + self.best_score = score + self.save_checkpoint(score, model) + self.counter = 0 + + def save_checkpoint(self, score, model): + '''Saves model when validation loss decrease.''' + if self.verbose: + print(f'Validation accuracy increased ({self.score_max:.6f} --> {score:.6f}). Saving model ...') + model.save_networks('best') + self.score_max = score \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..d7b790b40d29d1f3bf02f398f3522eea8e4c2c22 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,43 @@ +from .clip_models import CLIPModel +from .imagenet_models import ImagenetModel + + +VALID_NAMES = [ + 'Imagenet:resnet18', + 'Imagenet:resnet34', + 'Imagenet:resnet50', + 'Imagenet:resnet101', + 'Imagenet:resnet152', + 'Imagenet:vgg11', + 'Imagenet:vgg19', + 'Imagenet:swin-b', + 'Imagenet:swin-s', + 'Imagenet:swin-t', + 'Imagenet:vit_b_16', + 'Imagenet:vit_b_32', + 'Imagenet:vit_l_16', + 'Imagenet:vit_l_32', + + 'CLIP:RN50', + 'CLIP:RN101', + 'CLIP:RN50x4', + 'CLIP:RN50x16', + 'CLIP:RN50x64', + 'CLIP:ViT-B/32', + 'CLIP:ViT-B/16', + 'CLIP:ViT-L/14', + 'CLIP:ViT-L/14@336px', +] + + + + + +def get_model(name): + assert name in VALID_NAMES + if name.startswith("Imagenet:"): + return ImagenetModel(name[9:]) + elif name.startswith("CLIP:"): + return CLIPModel(name[5:]) + else: + assert False diff --git a/models/__pycache__/__init__.cpython-38.pyc b/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1a1cce9fe517a5f117e43b44707b50e6876b044 Binary files /dev/null and b/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/models/__pycache__/clip_models.cpython-38.pyc b/models/__pycache__/clip_models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb999d5cd41f738846281335d39a209b876b311d Binary files /dev/null and b/models/__pycache__/clip_models.cpython-38.pyc differ diff --git a/models/__pycache__/imagenet_models.cpython-38.pyc b/models/__pycache__/imagenet_models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..764df5a04395f68f568932d501916d15afb0d783 Binary files /dev/null and b/models/__pycache__/imagenet_models.cpython-38.pyc differ diff --git a/models/__pycache__/resnet.cpython-38.pyc b/models/__pycache__/resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6187e985dc1bab944661a92fb63b410d919d4bcb Binary files /dev/null and b/models/__pycache__/resnet.cpython-38.pyc differ diff --git a/models/__pycache__/vision_transformer.cpython-38.pyc b/models/__pycache__/vision_transformer.cpython-38.pyc new file mode 100644 
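# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the patch): how the
# EarlyStopping class above is driven. Although its docstring mentions
# validation *loss*, the implementation tracks a score that should increase
# (it saves on improvement and prints "Validation accuracy increased"), so the
# caller is expected to pass validation accuracy. The _Stub class is a
# hypothetical stand-in for a trainer that exposes save_networks(name).
from earlystop import EarlyStopping

class _Stub:
    def save_networks(self, name):
        print(f"checkpoint saved as '{name}'")

stopper = EarlyStopping(patience=3, delta=0.0, verbose=True)
trainer = _Stub()
for epoch, val_acc in enumerate([0.80, 0.86, 0.85, 0.85, 0.85, 0.85]):
    stopper(val_acc, trainer)                 # calls trainer.save_networks('best') on improvement
    if stopper.early_stop:
        print("stopping early at epoch", epoch)
        break
# ---------------------------------------------------------------------------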
index 0000000000000000000000000000000000000000..b2be6e3ebdec9be4626157a7bad115997883ce2f Binary files /dev/null and b/models/__pycache__/vision_transformer.cpython-38.pyc differ diff --git a/models/__pycache__/vision_transformer_misc.cpython-38.pyc b/models/__pycache__/vision_transformer_misc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2aeb0c59010960463f6eb6f5b86dd3f12a98eaff Binary files /dev/null and b/models/__pycache__/vision_transformer_misc.cpython-38.pyc differ diff --git a/models/__pycache__/vision_transformer_utils.cpython-38.pyc b/models/__pycache__/vision_transformer_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e11a8638c4b9369422ddb1f5068de631e7e7193 Binary files /dev/null and b/models/__pycache__/vision_transformer_utils.cpython-38.pyc differ diff --git a/models/clip/__init__.py b/models/clip/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..dcc5619538c0f7c782508bdbd9587259d805e0d9 --- /dev/null +++ b/models/clip/__init__.py @@ -0,0 +1 @@ +from .clip import * diff --git a/models/clip/__pycache__/__init__.cpython-310.pyc b/models/clip/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..432c7bbc639332fd8a9ca8c918e5abf60dc0f6e7 Binary files /dev/null and b/models/clip/__pycache__/__init__.cpython-310.pyc differ diff --git a/models/clip/__pycache__/__init__.cpython-38.pyc b/models/clip/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb682e6bee48e3e2c136ba17f04cd73919ab556b Binary files /dev/null and b/models/clip/__pycache__/__init__.cpython-38.pyc differ diff --git a/models/clip/__pycache__/__init__.cpython-39.pyc b/models/clip/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88a3e863c93116cbddc25c6bc982b704c49876a8 Binary files /dev/null and b/models/clip/__pycache__/__init__.cpython-39.pyc differ diff --git a/models/clip/__pycache__/clip.cpython-310.pyc b/models/clip/__pycache__/clip.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f839e4303bbe911212f80a5ff27855cfa089f8b Binary files /dev/null and b/models/clip/__pycache__/clip.cpython-310.pyc differ diff --git a/models/clip/__pycache__/clip.cpython-38.pyc b/models/clip/__pycache__/clip.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e155bac24424e73d569bb0def0a04664a5c86ceb Binary files /dev/null and b/models/clip/__pycache__/clip.cpython-38.pyc differ diff --git a/models/clip/__pycache__/clip.cpython-39.pyc b/models/clip/__pycache__/clip.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b02b7c686dc8e91d7709ff9bd8f1135d2f880601 Binary files /dev/null and b/models/clip/__pycache__/clip.cpython-39.pyc differ diff --git a/models/clip/__pycache__/model.cpython-310.pyc b/models/clip/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fe5ed8576243c0dd4feb59baaf9af5395fd21fe Binary files /dev/null and b/models/clip/__pycache__/model.cpython-310.pyc differ diff --git a/models/clip/__pycache__/model.cpython-38.pyc b/models/clip/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2914c8814c3e8c2eee6edc3fe6f523e1d483b5e6 Binary files /dev/null and b/models/clip/__pycache__/model.cpython-38.pyc differ diff --git 
a/models/clip/__pycache__/model.cpython-39.pyc b/models/clip/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9508370bcab3e891fb9164af0c5628e1577bb7b4 Binary files /dev/null and b/models/clip/__pycache__/model.cpython-39.pyc differ diff --git a/models/clip/__pycache__/simple_tokenizer.cpython-310.pyc b/models/clip/__pycache__/simple_tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e05bfbaf5118749c7a574133c43e3995ef57951a Binary files /dev/null and b/models/clip/__pycache__/simple_tokenizer.cpython-310.pyc differ diff --git a/models/clip/__pycache__/simple_tokenizer.cpython-38.pyc b/models/clip/__pycache__/simple_tokenizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1978ae92b14faa3f0edf637722d4dcb3728b8fb4 Binary files /dev/null and b/models/clip/__pycache__/simple_tokenizer.cpython-38.pyc differ diff --git a/models/clip/__pycache__/simple_tokenizer.cpython-39.pyc b/models/clip/__pycache__/simple_tokenizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31351f58c8b617d682cdcc547e745e6fbc429ba9 Binary files /dev/null and b/models/clip/__pycache__/simple_tokenizer.cpython-39.pyc differ diff --git a/models/clip/clip.py b/models/clip/clip.py new file mode 100755 index 0000000000000000000000000000000000000000..257511e1d40c120e0d64a0f1562d44b2b8a40a17 --- /dev/null +++ b/models/clip/clip.py @@ -0,0 +1,237 @@ +import hashlib +import os +import urllib +import warnings +from typing import Any, Union, List +from pkg_resources import packaging + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def 
_download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
+ """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/models/clip/model.py b/models/clip/model.py new file mode 100755 index 0000000000000000000000000000000000000000..4b6152174b1ddf4be5015ccbe8c4391607020841 --- /dev/null +++ b/models/clip/model.py @@ -0,0 +1,487 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + 
q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def 
attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + out = {} + for idx, layer in enumerate(self.resblocks.children()): + x = layer(x) + out['layer'+str(idx)] = x[0] # shape:LND. choose cls token feature + return out, x + + # return self.resblocks(x) # This is the original code + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + + + def forward(self, x: torch.Tensor): + """ + 原代码这里的x是4个dimension,即batchsize*RGBchannels*224*224 + 若只输入一张图片,因为没有batchsize维度,需要在最前面加一个维度,见下面第一行代码 + """ + x = x.reshape(-1,x.shape[-3],x.shape[-2],x.shape[-1]) + # print(x.shape) + x = self.conv1(x) # shape = [*, width, grid, grid] + # print(x.shape) + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + # print(x.shape) + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + # print(self.class_embedding.to(x.dtype).shape) + # print(torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device).shape) + # print(x.shape) + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + out, x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + + out['before_projection'] = x + + if self.proj is not None: + x = x @ self.proj + out['after_projection'] = x + + """ + 将ViT-Large中第20,22,24层(或16,20,24)的[cls]feature做加权平均, 经过projection后输出 + """ + out['res_output'] = torch.zeros_like(out['before_projection']) + # for layer_output in [[0.2, out['layer15']], [0.3, out['layer19']], [0.5, out['layer23']]]: + for layer_output in [[0.2, out['layer19']], [0.3, out['layer21']], [0.5, out['layer23']]]: + # for layer_output in [[0.5, out['layer15']], [0.5, out['layer21']]]: + # layer_output[1] = layer_output[1].permute(1, 0, 2) # LND -> NLD + layer_output[1] = self.ln_post(layer_output[1]) + out['res_output'] += layer_output[0]*layer_output[1] + out['res_output'] = out['res_output'] @ self.proj + + """ + 将ViT每一层Encoder的[cls]feature都输出 + + 形式e.g. + out['layer0'] = ... + out['layer1'] = ... 
+ """ + # Return both intermediate features and final clip feature + return out + + # This only returns CLIP features + # return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 
+ + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) # 经过修改,self.encode_image(image)输出的是每一层Encoder的[cls]feature + text_features = self.encode_text(text) + + # 对倒数3层的[cls]feature做平均 + image_features = (image_features['layer'+str(self.vision_layers-1)]+image_features['layer'+str(self.vision_layers-2)]+image_features['layer'+str(self.vision_layers-3)])/3 + # 对倒数3层的[cls]feature做加权平均 + # image_features = 0.5*image_features['layer'+str(self.vision_layers-1)] + 0.3*image_features['layer'+str(self.vision_layers-2)] + 0.2*image_features['layer'+str(self.vision_layers-3)] + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if 
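# ---------------------------------------------------------------------------
# Editor's note (illustrative, not part of the patch): translating the Chinese
# comments in CLIP.forward above: encode_image() has been modified to return
# every encoder layer's [cls] feature, and forward() averages the [cls]
# features of the last three layers (a weighted-average variant is left
# commented out). Note that forward() reads self.vision_layers, which the
# __init__ shown here never assigns; the detector pipeline in this patch only
# calls encode_image(), so that joint image/text path is not exercised. A
# minimal, hypothetical fix would be to store the value in __init__:
#     self.vision_layers = vision_layers
# ---------------------------------------------------------------------------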
k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() \ No newline at end of file diff --git a/models/clip/simple_tokenizer.py b/models/clip/simple_tokenizer.py new file mode 100755 index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91 --- /dev/null +++ b/models/clip/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). 
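+    For example, get_pairs(("h", "e", "l", "l", "o")) returns
+    {("h", "e"), ("e", "l"), ("l", "l"), ("l", "o")}.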
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer(object):
+    def __init__(self, bpe_path: str = default_bpe()):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+        merges = merges[1:49152-256-2+1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v+'</w>' for v in vocab]
+        for merge in merges:
+            vocab.append(''.join(merge))
+        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token+'</w>'
+
+        while True:
+            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word)-1 and word[i+1] == second:
+                    new_word.append(first+second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+        return text
diff --git a/models/clip_models.py b/models/clip_models.py
new file mode 100755
index 0000000000000000000000000000000000000000..6c3aa6736e452dd67fd00f821bb754c477aeed82
--- /dev/null
+++ b/models/clip_models.py
@@ -0,0 +1,35 @@
+from .clip import clip
+from PIL import Image
+import torch.nn as nn
+
+
+CHANNELS = {
+    "RN50" : 1024,
+    "ViT-L/14" : 768
+}
+
+class CLIPModel(nn.Module):
+    def __init__(self, name, num_classes=1):
+        super(CLIPModel, self).__init__()
+
+        self.model, self.preprocess = clip.load(name, device="cpu")  # self.preprocess will not be used during training, which is handled in the Dataset class
+        self.fc = nn.Linear( CHANNELS[name], num_classes )
+
+
+    def forward(self, x, return_feature=False):
+        features = self.model.encode_image(x)
+        # print(features.keys())
+        """
+        The backbone is ViT-Large, 24 encoder layers in total.
+        The [cls] features of layers 24, 22 and 20 can be weight-averaged (see the commented-out line below).
+        """
+        if return_feature:
+            
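+            # 'after_projection' is the [cls] embedding after CLIP's final visual projection
+            # (768-d for ViT-L/14); the dict keys used here are assumed to be produced by the
+            # modified encode_image in the vendored CLIP model code above.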
return features['after_projection'] + # print(features['after_projection'].shape) + # print(features['layer21'].shape) + # print(features['layer19'].shape) + # features = 0.5*features['after_projection'] + 0.3*features['layer21'] + 0.2*features['layer19'] + # print(features.shape) + features = features['res_output'] + return self.fc(features) + diff --git a/models/imagenet_models.py b/models/imagenet_models.py new file mode 100755 index 0000000000000000000000000000000000000000..20a40b916793d926c915aa2f62602651613fec04 --- /dev/null +++ b/models/imagenet_models.py @@ -0,0 +1,40 @@ +from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152 +from .vision_transformer import vit_b_16, vit_b_32, vit_l_16, vit_l_32 + +from torchvision import transforms +from PIL import Image +import torch +import torch.nn as nn + + +model_dict = { + 'resnet18': resnet18, + 'resnet34': resnet34, + 'resnet50': resnet50, + 'resnet101': resnet101, + 'resnet152': resnet152, + 'vit_b_16': vit_b_16, + 'vit_b_32': vit_b_32, + 'vit_l_16': vit_l_16, + 'vit_l_32': vit_l_32 +} + + +CHANNELS = { + "resnet50" : 2048, + "vit_b_16" : 768, +} + + + +class ImagenetModel(nn.Module): + def __init__(self, name, num_classes=1): + super(ImagenetModel, self).__init__() + + self.model = model_dict[name](pretrained=True) + self.fc = nn.Linear(CHANNELS[name], num_classes) #manually define a fc layer here + + + def forward(self, x): + feature = self.model(x)["penultimate"] + return self.fc(feature) diff --git a/models/resnet.py b/models/resnet.py new file mode 100755 index 0000000000000000000000000000000000000000..a78e3d65e263cb9dbd1afa0e1a88dba9f5ddd164 --- /dev/null +++ b/models/resnet.py @@ -0,0 +1,337 @@ +import torch +from torch import Tensor +import torch.nn as nn +from typing import Type, Any, Callable, Union, List, Optional + +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(BasicBlock, 
self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + 
self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + # The comment resolution is based on input size is 224*224 imagenet + out = {} + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + out['f0'] = x # N*64*56*56 + + x = self.layer1(x) + out['f1'] = x # N*64*56*56 + + x = self.layer2(x) + out['f2'] = x # N*128*28*28 + + x = self.layer3(x) + out['f3'] = x # N*256*14*14 + + x = self.layer4(x) + out['f4'] = x # N*512*7*7 + + x = self.avgpool(x) + x = torch.flatten(x, 1) + out['penultimate'] = x # N*512 + + x = self.fc(x) + out['logits'] = x # N*1000 + + # return all features + return out + + # return final classification result + # return x + + def forward(self, x): + return self._forward_impl(x) + + +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + **kwargs: Any +) -> ResNet: + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> 
ResNet: + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) + + +def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) + + +def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) + + +def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) + + +def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_. 
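+    Note: like the other constructors in this file, the returned model's ``forward`` yields a
+    dict of intermediate features ('f0'-'f4'), the pooled 'penultimate' vector and the final
+    'logits'; e.g. ``resnet152(pretrained=False)(torch.randn(1, 3, 224, 224))['penultimate']``
+    has shape ``[1, 2048]``.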
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs) + diff --git a/models/vgg.py b/models/vgg.py new file mode 100755 index 0000000000000000000000000000000000000000..a30a1df18a64f9ab2ca309b264cd4e8409b0cf64 --- /dev/null +++ b/models/vgg.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +from typing import Union, List, Dict, Any, cast +import torchvision +import torch.nn.functional as F + + + + + +class VGG(torch.nn.Module): + def __init__(self, arch_type, pretrained, progress): + super().__init__() + + self.layer1 = torch.nn.Sequential() + self.layer2 = torch.nn.Sequential() + self.layer3 = torch.nn.Sequential() + self.layer4 = torch.nn.Sequential() + self.layer5 = torch.nn.Sequential() + + if arch_type == 'vgg11': + official_vgg = torchvision.models.vgg11(pretrained=pretrained, progress=progress) + blocks = [ [0,2], [2,5], [5,10], [10,15], [15,20] ] + last_idx = 20 + elif arch_type == 'vgg19': + official_vgg = torchvision.models.vgg19(pretrained=pretrained, progress=progress) + blocks = [ [0,4], [4,9], [9,18], [18,27], [27,36] ] + last_idx = 36 + else: + raise NotImplementedError + + + for x in range( *blocks[0] ): + self.layer1.add_module(str(x), official_vgg.features[x]) + for x in range( *blocks[1] ): + self.layer2.add_module(str(x), official_vgg.features[x]) + for x in range( *blocks[2] ): + self.layer3.add_module(str(x), official_vgg.features[x]) + for x in range( *blocks[3] ): + self.layer4.add_module(str(x), official_vgg.features[x]) + for x in range( *blocks[4] ): + self.layer5.add_module(str(x), official_vgg.features[x]) + + self.max_pool = official_vgg.features[last_idx] + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + + self.fc1 = official_vgg.classifier[0] + self.fc2 = official_vgg.classifier[3] + self.fc3 = official_vgg.classifier[6] + self.dropout = nn.Dropout() + + + def forward(self, x): + out = {} + + x = self.layer1(x) + out['f0'] = x + + x = self.layer2(x) + out['f1'] = x + + x = self.layer3(x) + out['f2'] = x + + x = self.layer4(x) + out['f3'] = x + + x = self.layer5(x) + out['f4'] = x + + x = self.max_pool(x) + x = self.avgpool(x) + x = x.view(-1,512*7*7) + + x = self.fc1(x) + x = F.relu(x) + x = self.dropout(x) + x = self.fc2(x) + x = F.relu(x) + out['penultimate'] = x + x = self.dropout(x) + x = self.fc3(x) + out['logits'] = x + + return out + + + + + + + + + + +def vgg11(pretrained=False, progress=True): + r"""VGG 11-layer model (configuration "A") from + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return VGG('vgg11', pretrained, progress) + + + +def vgg19(pretrained=False, progress=True): + r"""VGG 19-layer model (configuration "E") + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 
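+    Note: as with vgg11, ``forward`` returns a dict of block features 'f0'-'f4', the 4096-d
+    'penultimate' activation and the final 'logits'.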
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return VGG('vgg19', pretrained, progress) + + + + diff --git a/models/vision_transformer.py b/models/vision_transformer.py new file mode 100755 index 0000000000000000000000000000000000000000..618e9626ca43f1afdb3419e19be11f3a3048f81e --- /dev/null +++ b/models/vision_transformer.py @@ -0,0 +1,481 @@ +import math +from collections import OrderedDict +from functools import partial +from typing import Any, Callable, List, NamedTuple, Optional + +import torch +import torch.nn as nn + +# from .._internally_replaced_utils import load_state_dict_from_url +from .vision_transformer_misc import ConvNormActivation +from .vision_transformer_utils import _log_api_usage_once + +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# __all__ = [ +# "VisionTransformer", +# "vit_b_16", +# "vit_b_32", +# "vit_l_16", +# "vit_l_32", +# ] + +model_urls = { + "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth", + "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", + "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", + "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth", +} + + +class ConvStemConfig(NamedTuple): + out_channels: int + kernel_size: int + stride: int + norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d + activation_layer: Callable[..., nn.Module] = nn.ReLU + + +class MLPBlock(nn.Sequential): + """Transformer MLP block.""" + + def __init__(self, in_dim: int, mlp_dim: int, dropout: float): + super().__init__() + self.linear_1 = nn.Linear(in_dim, mlp_dim) + self.act = nn.GELU() + self.dropout_1 = nn.Dropout(dropout) + self.linear_2 = nn.Linear(mlp_dim, in_dim) + self.dropout_2 = nn.Dropout(dropout) + + nn.init.xavier_uniform_(self.linear_1.weight) + nn.init.xavier_uniform_(self.linear_2.weight) + nn.init.normal_(self.linear_1.bias, std=1e-6) + nn.init.normal_(self.linear_2.bias, std=1e-6) + + +class EncoderBlock(nn.Module): + """Transformer encoder block.""" + + def __init__( + self, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ): + super().__init__() + self.num_heads = num_heads + + # Attention block + self.ln_1 = norm_layer(hidden_dim) + self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True) + self.dropout = nn.Dropout(dropout) + + # MLP block + self.ln_2 = norm_layer(hidden_dim) + self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout) + + def forward(self, input: torch.Tensor): + torch._assert(input.dim() == 3, f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}") + x = self.ln_1(input) + x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False) + x = self.dropout(x) + x = x + input + + y = self.ln_2(x) + y = self.mlp(y) + return x + y + + +class Encoder(nn.Module): + """Transformer Model Encoder for sequence to sequence translation.""" + + def __init__( + self, + seq_length: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ): + super().__init__() + # Note that batch_size is on the first dim 
because + # we have batch_first=True in nn.MultiAttention() by default + self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT + self.dropout = nn.Dropout(dropout) + layers: OrderedDict[str, nn.Module] = OrderedDict() + for i in range(num_layers): + layers[f"encoder_layer_{i}"] = EncoderBlock( + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + norm_layer, + ) + self.layers = nn.Sequential(layers) + self.ln = norm_layer(hidden_dim) + + def forward(self, input: torch.Tensor): + torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") + input = input + self.pos_embedding + return self.ln(self.layers(self.dropout(input))) + + +class VisionTransformer(nn.Module): + """Vision Transformer as per https://arxiv.org/abs/2010.11929.""" + + def __init__( + self, + image_size: int, + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float = 0.0, + attention_dropout: float = 0.0, + num_classes: int = 1000, + representation_size: Optional[int] = None, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + conv_stem_configs: Optional[List[ConvStemConfig]] = None, + ): + super().__init__() + _log_api_usage_once(self) + torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!") + self.image_size = image_size + self.patch_size = patch_size + self.hidden_dim = hidden_dim + self.mlp_dim = mlp_dim + self.attention_dropout = attention_dropout + self.dropout = dropout + self.num_classes = num_classes + self.representation_size = representation_size + self.norm_layer = norm_layer + + if conv_stem_configs is not None: + # As per https://arxiv.org/abs/2106.14881 + seq_proj = nn.Sequential() + prev_channels = 3 + for i, conv_stem_layer_config in enumerate(conv_stem_configs): + seq_proj.add_module( + f"conv_bn_relu_{i}", + ConvNormActivation( + in_channels=prev_channels, + out_channels=conv_stem_layer_config.out_channels, + kernel_size=conv_stem_layer_config.kernel_size, + stride=conv_stem_layer_config.stride, + norm_layer=conv_stem_layer_config.norm_layer, + activation_layer=conv_stem_layer_config.activation_layer, + ), + ) + prev_channels = conv_stem_layer_config.out_channels + seq_proj.add_module( + "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1) + ) + self.conv_proj: nn.Module = seq_proj + else: + self.conv_proj = nn.Conv2d( + in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size + ) + + seq_length = (image_size // patch_size) ** 2 + + # Add a class token + self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim)) + seq_length += 1 + + self.encoder = Encoder( + seq_length, + num_layers, + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + norm_layer, + ) + self.seq_length = seq_length + + heads_layers: OrderedDict[str, nn.Module] = OrderedDict() + if representation_size is None: + heads_layers["head"] = nn.Linear(hidden_dim, num_classes) + else: + heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size) + heads_layers["act"] = nn.Tanh() + heads_layers["head"] = nn.Linear(representation_size, num_classes) + + self.heads = nn.Sequential(heads_layers) + + if isinstance(self.conv_proj, nn.Conv2d): + # Init the patchify stem + fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1] + nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in)) 
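+            # e.g. for the default 16x16 patchify stem: fan_in = 3 * 16 * 16 = 768,
+            # so the weights are drawn with std = sqrt(1/768), roughly 0.036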
+ if self.conv_proj.bias is not None: + nn.init.zeros_(self.conv_proj.bias) + elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d): + # Init the last 1x1 conv of the conv stem + nn.init.normal_( + self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels) + ) + if self.conv_proj.conv_last.bias is not None: + nn.init.zeros_(self.conv_proj.conv_last.bias) + + if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear): + fan_in = self.heads.pre_logits.in_features + nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in)) + nn.init.zeros_(self.heads.pre_logits.bias) + + if isinstance(self.heads.head, nn.Linear): + nn.init.zeros_(self.heads.head.weight) + nn.init.zeros_(self.heads.head.bias) + + def _process_input(self, x: torch.Tensor) -> torch.Tensor: + n, c, h, w = x.shape + p = self.patch_size + torch._assert(h == self.image_size, "Wrong image height!") + torch._assert(w == self.image_size, "Wrong image width!") + n_h = h // p + n_w = w // p + + # (n, c, h, w) -> (n, hidden_dim, n_h, n_w) + x = self.conv_proj(x) + # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w)) + x = x.reshape(n, self.hidden_dim, n_h * n_w) + + # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim) + # The self attention layer expects inputs in the format (N, S, E) + # where S is the source sequence length, N is the batch size, E is the + # embedding dimension + x = x.permute(0, 2, 1) + + return x + + def forward(self, x: torch.Tensor): + out = {} + + # Reshape and permute the input tensor + x = self._process_input(x) + n = x.shape[0] + + # Expand the class token to the full batch + batch_class_token = self.class_token.expand(n, -1, -1) + x = torch.cat([batch_class_token, x], dim=1) + + + x = self.encoder(x) + img_feature = x[:,1:] + H = W = int(self.image_size / self.patch_size) + out['f4'] = img_feature.view(n, H, W, self.hidden_dim).permute(0,3,1,2) + + # Classifier "token" as used by standard language architectures + x = x[:, 0] + out['penultimate'] = x + + x = self.heads(x) # I checked that for all pretrained ViT, this is just a fc + out['logits'] = x + + return out + + +def _vision_transformer( + arch: str, + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + pretrained: bool, + progress: bool, + **kwargs: Any, +) -> VisionTransformer: + image_size = kwargs.pop("image_size", 224) + + model = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + **kwargs, + ) + + if pretrained: + if arch not in model_urls: + raise ValueError(f"No checkpoint is available for model type '{arch}'!") + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + + return model + + +def vit_b_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: + """ + Constructs a vit_b_16 architecture from + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. 
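+    Note: this variant uses 12 layers, 12 heads, hidden_dim 768 and 16x16 patches, i.e.
+    (224 // 16) ** 2 + 1 = 197 tokens for the default 224x224 input.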
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _vision_transformer( + arch="vit_b_16", + patch_size=16, + num_layers=12, + num_heads=12, + hidden_dim=768, + mlp_dim=3072, + pretrained=pretrained, + progress=progress, + **kwargs, + ) + + +def vit_b_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: + """ + Constructs a vit_b_32 architecture from + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _vision_transformer( + arch="vit_b_32", + patch_size=32, + num_layers=12, + num_heads=12, + hidden_dim=768, + mlp_dim=3072, + pretrained=pretrained, + progress=progress, + **kwargs, + ) + + +def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: + """ + Constructs a vit_l_16 architecture from + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _vision_transformer( + arch="vit_l_16", + patch_size=16, + num_layers=24, + num_heads=16, + hidden_dim=1024, + mlp_dim=4096, + pretrained=pretrained, + progress=progress, + **kwargs, + ) + + +def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: + """ + Constructs a vit_l_32 architecture from + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _vision_transformer( + arch="vit_l_32", + patch_size=32, + num_layers=24, + num_heads=16, + hidden_dim=1024, + mlp_dim=4096, + pretrained=pretrained, + progress=progress, + **kwargs, + ) + + +def interpolate_embeddings( + image_size: int, + patch_size: int, + model_state: "OrderedDict[str, torch.Tensor]", + interpolation_mode: str = "bicubic", + reset_heads: bool = False, +) -> "OrderedDict[str, torch.Tensor]": + """This function helps interpolating positional embeddings during checkpoint loading, + especially when you want to apply a pre-trained model on images with different resolution. + + Args: + image_size (int): Image size of the new model. + patch_size (int): Patch size of the new model. + model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model. + interpolation_mode (str): The algorithm used for upsampling. Default: bicubic. + reset_heads (bool): If true, not copying the state of heads. Default: False. + + Returns: + OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model. + """ + # Shape of pos_embedding is (1, seq_length, hidden_dim) + pos_embedding = model_state["encoder.pos_embedding"] + n, seq_length, hidden_dim = pos_embedding.shape + if n != 1: + raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}") + + new_seq_length = (image_size // patch_size) ** 2 + 1 + + # Need to interpolate the weights for the position embedding. + # We do this by reshaping the positions embeddings to a 2d grid, performing + # an interpolation in the (h, w) space and then reshaping back to a 1d grid. 
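+    # e.g. a checkpoint trained at 224px with 16px patches has 14 * 14 + 1 = 197 positions;
+    # adapting it to 384px input needs 24 * 24 + 1 = 577, so the 14x14 grid is resized to 24x24 below.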
+ if new_seq_length != seq_length: + # The class token embedding shouldn't be interpolated so we split it up. + seq_length -= 1 + new_seq_length -= 1 + pos_embedding_token = pos_embedding[:, :1, :] + pos_embedding_img = pos_embedding[:, 1:, :] + + # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length) + pos_embedding_img = pos_embedding_img.permute(0, 2, 1) + seq_length_1d = int(math.sqrt(seq_length)) + torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!") + + # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d) + pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d) + new_seq_length_1d = image_size // patch_size + + # Perform interpolation. + # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) + new_pos_embedding_img = nn.functional.interpolate( + pos_embedding_img, + size=new_seq_length_1d, + mode=interpolation_mode, + align_corners=True, + ) + + # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length) + new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length) + + # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim) + new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1) + new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1) + + model_state["encoder.pos_embedding"] = new_pos_embedding + + if reset_heads: + model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict() + for k, v in model_state.items(): + if not k.startswith("heads"): + model_state_copy[k] = v + model_state = model_state_copy + + return model_state diff --git a/models/vision_transformer_misc.py b/models/vision_transformer_misc.py new file mode 100755 index 0000000000000000000000000000000000000000..7915f036c00f0d9c57c176e621afc9f1e69dcb30 --- /dev/null +++ b/models/vision_transformer_misc.py @@ -0,0 +1,163 @@ +from typing import Callable, List, Optional + +import torch +from torch import Tensor + +from .vision_transformer_utils import _log_api_usage_once + + +interpolate = torch.nn.functional.interpolate + + +# This is not in nn +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed + + Args: + num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)`` + eps (float): a value added to the denominator for numerical stability. 
Default: 1e-5 + """ + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + ): + super().__init__() + _log_api_usage_once(self) + self.eps = eps + self.register_buffer("weight", torch.ones(num_features)) + self.register_buffer("bias", torch.zeros(num_features)) + self.register_buffer("running_mean", torch.zeros(num_features)) + self.register_buffer("running_var", torch.ones(num_features)) + + def _load_from_state_dict( + self, + state_dict: dict, + prefix: str, + local_metadata: dict, + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x: Tensor) -> Tensor: + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + scale = w * (rv + self.eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})" + + +class ConvNormActivation(torch.nn.Sequential): + """ + Configurable block used for Convolution-Normalzation-Activation blocks. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the Convolution-Normalzation-Activation block + kernel_size: (int, optional): Size of the convolving kernel. Default: 3 + stride (int, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in wich case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation`` + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolutiuon layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d`` + activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU`` + dilation (int): Spacing between kernel elements. Default: 1 + inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` + bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. 
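+    Example (illustrative): ``ConvNormActivation(3, 64, kernel_size=3, stride=2)`` is equivalent to
+    ``Conv2d(3, 64, 3, stride=2, padding=1, bias=False)`` followed by ``BatchNorm2d(64)`` and ``ReLU(inplace=True)``.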
+ + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: Optional[int] = None, + groups: int = 1, + norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + dilation: int = 1, + inplace: Optional[bool] = True, + bias: Optional[bool] = None, + ) -> None: + if padding is None: + padding = (kernel_size - 1) // 2 * dilation + if bias is None: + bias = norm_layer is None + layers = [ + torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + ] + if norm_layer is not None: + layers.append(norm_layer(out_channels)) + if activation_layer is not None: + params = {} if inplace is None else {"inplace": inplace} + layers.append(activation_layer(**params)) + super().__init__(*layers) + _log_api_usage_once(self) + self.out_channels = out_channels + + +class SqueezeExcitation(torch.nn.Module): + """ + This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1). + Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in in eq. 3. + + Args: + input_channels (int): Number of channels in the input image + squeeze_channels (int): Number of squeeze channels + activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU`` + scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid`` + """ + + def __init__( + self, + input_channels: int, + squeeze_channels: int, + activation: Callable[..., torch.nn.Module] = torch.nn.ReLU, + scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid, + ) -> None: + super().__init__() + _log_api_usage_once(self) + self.avgpool = torch.nn.AdaptiveAvgPool2d(1) + self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1) + self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1) + self.activation = activation() + self.scale_activation = scale_activation() + + def _scale(self, input: Tensor) -> Tensor: + scale = self.avgpool(input) + scale = self.fc1(scale) + scale = self.activation(scale) + scale = self.fc2(scale) + return self.scale_activation(scale) + + def forward(self, input: Tensor) -> Tensor: + scale = self._scale(input) + return scale * input diff --git a/models/vision_transformer_utils.py b/models/vision_transformer_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..6d3293d103d0e186a1244e7cc0c6e3bde63d1df3 --- /dev/null +++ b/models/vision_transformer_utils.py @@ -0,0 +1,549 @@ +import math +import pathlib +import warnings +from types import FunctionType +from typing import Any, BinaryIO, List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL import Image, ImageColor, ImageDraw, ImageFont + +__all__ = [ + "make_grid", + "save_image", + "draw_bounding_boxes", + "draw_segmentation_masks", + "draw_keypoints", + "flow_to_image", +] + + +@torch.no_grad() +def make_grid( + tensor: Union[torch.Tensor, List[torch.Tensor]], + nrow: int = 8, + padding: int = 2, + normalize: bool = False, + value_range: Optional[Tuple[int, int]] = None, + scale_each: bool = False, + pad_value: float = 0.0, + **kwargs, +) -> torch.Tensor: + """ + Make a grid of images. + + Args: + tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) + or a list of images all of the same size. 
+ nrow (int, optional): Number of images displayed in each row of the grid. + The final grid size is ``(B / nrow, nrow)``. Default: ``8``. + padding (int, optional): amount of padding. Default: ``2``. + normalize (bool, optional): If True, shift the image to the range (0, 1), + by the min and max values specified by ``value_range``. Default: ``False``. + value_range (tuple, optional): tuple (min, max) where min and max are numbers, + then these numbers are used to normalize the image. By default, min and max + are computed from the tensor. + range (tuple. optional): + .. warning:: + This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``value_range`` + instead. + scale_each (bool, optional): If ``True``, scale each image in the batch of + images separately rather than the (min, max) over all images. Default: ``False``. + pad_value (float, optional): Value for the padded pixels. Default: ``0``. + + Returns: + grid (Tensor): the tensor containing grid of images. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(make_grid) + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") + + if "range" in kwargs.keys(): + warnings.warn( + "The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. " + "Please use 'value_range' instead." + ) + value_range = kwargs["range"] + + # if list of tensors, convert to a 4D mini-batch Tensor + if isinstance(tensor, list): + tensor = torch.stack(tensor, dim=0) + + if tensor.dim() == 2: # single image H x W + tensor = tensor.unsqueeze(0) + if tensor.dim() == 3: # single image + if tensor.size(0) == 1: # if single-channel, convert to 3-channel + tensor = torch.cat((tensor, tensor, tensor), 0) + tensor = tensor.unsqueeze(0) + + if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images + tensor = torch.cat((tensor, tensor, tensor), 1) + + if normalize is True: + tensor = tensor.clone() # avoid modifying tensor in-place + if value_range is not None: + assert isinstance( + value_range, tuple + ), "value_range has to be a tuple (min, max) if specified. 
min and max are numbers" + + def norm_ip(img, low, high): + img.clamp_(min=low, max=high) + img.sub_(low).div_(max(high - low, 1e-5)) + + def norm_range(t, value_range): + if value_range is not None: + norm_ip(t, value_range[0], value_range[1]) + else: + norm_ip(t, float(t.min()), float(t.max())) + + if scale_each is True: + for t in tensor: # loop over mini-batch dimension + norm_range(t, value_range) + else: + norm_range(tensor, value_range) + + assert isinstance(tensor, torch.Tensor) + if tensor.size(0) == 1: + return tensor.squeeze(0) + + # make the mini-batch of images into a grid + nmaps = tensor.size(0) + xmaps = min(nrow, nmaps) + ymaps = int(math.ceil(float(nmaps) / xmaps)) + height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) + num_channels = tensor.size(1) + grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value) + k = 0 + for y in range(ymaps): + for x in range(xmaps): + if k >= nmaps: + break + # Tensor.copy_() is a valid method but seems to be missing from the stubs + # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_ + grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined] + 2, x * width + padding, width - padding + ).copy_(tensor[k]) + k = k + 1 + return grid + + +@torch.no_grad() +def save_image( + tensor: Union[torch.Tensor, List[torch.Tensor]], + fp: Union[str, pathlib.Path, BinaryIO], + format: Optional[str] = None, + **kwargs, +) -> None: + """ + Save a given Tensor into an image file. + + Args: + tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, + saves the tensor as a grid of images by calling ``make_grid``. + fp (string or file object): A filename or a file object + format(Optional): If omitted, the format to use is determined from the filename extension. + If a file object was used instead of a filename, this parameter should always be used. + **kwargs: Other arguments are documented in ``make_grid``. + """ + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(save_image) + grid = make_grid(tensor, **kwargs) + # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer + ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + im = Image.fromarray(ndarr) + im.save(fp, format=format) + + +@torch.no_grad() +def draw_bounding_boxes( + image: torch.Tensor, + boxes: torch.Tensor, + labels: Optional[List[str]] = None, + colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, + fill: Optional[bool] = False, + width: int = 1, + font: Optional[str] = None, + font_size: int = 10, +) -> torch.Tensor: + + """ + Draws bounding boxes on given image. + The values of the input image should be uint8 between 0 and 255. + If fill is True, Resulting Tensor should be saved as PNG image. + + Args: + image (Tensor): Tensor of shape (C x H x W) and dtype uint8. + boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that + the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and + `0 <= ymin < ymax < H`. + labels (List[str]): List containing the labels of bounding boxes. + colors (color or list of colors, optional): List containing the colors + of the boxes or single color for all boxes. The color can be represented as + PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. 
+ By default, random colors are generated for boxes. + fill (bool): If `True` fills the bounding box with specified color. + width (int): Width of bounding box. + font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may + also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, + `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. + font_size (int): The requested font size in points. + + Returns: + img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. + """ + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(draw_bounding_boxes) + if not isinstance(image, torch.Tensor): + raise TypeError(f"Tensor expected, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"Tensor uint8 expected, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size(0) not in {1, 3}: + raise ValueError("Only grayscale and RGB images are supported") + + num_boxes = boxes.shape[0] + + if labels is None: + labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef] + elif len(labels) != num_boxes: + raise ValueError( + f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box." + ) + + if colors is None: + colors = _generate_color_palette(num_boxes) + elif isinstance(colors, list): + if len(colors) < num_boxes: + raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ") + else: # colors specifies a single color for all boxes + colors = [colors] * num_boxes + + colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors] + + # Handle Grayscale images + if image.size(0) == 1: + image = torch.tile(image, (3, 1, 1)) + + ndarr = image.permute(1, 2, 0).cpu().numpy() + img_to_draw = Image.fromarray(ndarr) + img_boxes = boxes.to(torch.int64).tolist() + + if fill: + draw = ImageDraw.Draw(img_to_draw, "RGBA") + else: + draw = ImageDraw.Draw(img_to_draw) + + txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) + + for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type] + if fill: + fill_color = color + (100,) + draw.rectangle(bbox, width=width, outline=color, fill=fill_color) + else: + draw.rectangle(bbox, width=width, outline=color) + + if label is not None: + margin = width + 1 + draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) + + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + + +@torch.no_grad() +def draw_segmentation_masks( + image: torch.Tensor, + masks: torch.Tensor, + alpha: float = 0.8, + colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, +) -> torch.Tensor: + + """ + Draws segmentation masks on given RGB image. + The values of the input image should be uint8 between 0 and 255. + + Args: + image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. + alpha (float): Float number between 0 and 1 denoting the transparency of the masks. + 0 means full transparency, 1 means no transparency. + colors (color or list of colors, optional): List containing the colors + of the masks or single color for all masks. The color can be represented as + PIL strings e.g. 
"red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + By default, random colors are generated for each mask. + + Returns: + img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top. + """ + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(draw_segmentation_masks) + if not isinstance(image, torch.Tensor): + raise TypeError(f"The image must be a tensor, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size()[0] != 3: + raise ValueError("Pass an RGB image. Other Image formats are not supported") + if masks.ndim == 2: + masks = masks[None, :, :] + if masks.ndim != 3: + raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)") + if masks.dtype != torch.bool: + raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}") + if masks.shape[-2:] != image.shape[-2:]: + raise ValueError("The image and the masks must have the same height and width") + + num_masks = masks.size()[0] + if colors is not None and num_masks > len(colors): + raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") + + if colors is None: + colors = _generate_color_palette(num_masks) + + if not isinstance(colors, list): + colors = [colors] + if not isinstance(colors[0], (tuple, str)): + raise ValueError("colors must be a tuple or a string, or a list thereof") + if isinstance(colors[0], tuple) and len(colors[0]) != 3: + raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") + + out_dtype = torch.uint8 + + colors_ = [] + for color in colors: + if isinstance(color, str): + color = ImageColor.getrgb(color) + colors_.append(torch.tensor(color, dtype=out_dtype)) + + img_to_draw = image.detach().clone() + # TODO: There might be a way to vectorize this + for mask, color in zip(masks, colors_): + img_to_draw[:, mask] = color[:, None] + + out = image * (1 - alpha) + img_to_draw * alpha + return out.to(out_dtype) + + +@torch.no_grad() +def draw_keypoints( + image: torch.Tensor, + keypoints: torch.Tensor, + connectivity: Optional[List[Tuple[int, int]]] = None, + colors: Optional[Union[str, Tuple[int, int, int]]] = None, + radius: int = 2, + width: int = 3, +) -> torch.Tensor: + + """ + Draws Keypoints on given RGB image. + The values of the input image should be uint8 between 0 and 255. + + Args: + image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, + in the format [x, y]. + connectivity (List[Tuple[int, int]]]): A List of tuple where, + each tuple contains pair of keypoints to be connected. + colors (str, Tuple): The color can be represented as + PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + radius (int): Integer denoting radius of keypoint. + width (int): Integer denoting width of line connecting keypoints. + + Returns: + img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. 
+ """ + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(draw_keypoints) + if not isinstance(image, torch.Tensor): + raise TypeError(f"The image must be a tensor, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size()[0] != 3: + raise ValueError("Pass an RGB image. Other Image formats are not supported") + + if keypoints.ndim != 3: + raise ValueError("keypoints must be of shape (num_instances, K, 2)") + + ndarr = image.permute(1, 2, 0).cpu().numpy() + img_to_draw = Image.fromarray(ndarr) + draw = ImageDraw.Draw(img_to_draw) + img_kpts = keypoints.to(torch.int64).tolist() + + for kpt_id, kpt_inst in enumerate(img_kpts): + for inst_id, kpt in enumerate(kpt_inst): + x1 = kpt[0] - radius + x2 = kpt[0] + radius + y1 = kpt[1] - radius + y2 = kpt[1] + radius + draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) + + if connectivity: + for connection in connectivity: + start_pt_x = kpt_inst[connection[0]][0] + start_pt_y = kpt_inst[connection[0]][1] + + end_pt_x = kpt_inst[connection[1]][0] + end_pt_y = kpt_inst[connection[1]][1] + + draw.line( + ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)), + width=width, + ) + + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + + +# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization +@torch.no_grad() +def flow_to_image(flow: torch.Tensor) -> torch.Tensor: + + """ + Converts a flow to an RGB image. + + Args: + flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float. + + Returns: + img (Tensor): Image Tensor of dtype uint8 where each color corresponds + to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input. + """ + + if flow.dtype != torch.float: + raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.") + + orig_shape = flow.shape + if flow.ndim == 3: + flow = flow[None] # Add batch dim + + if flow.ndim != 4 or flow.shape[1] != 2: + raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.") + + max_norm = torch.sum(flow ** 2, dim=1).sqrt().max() + epsilon = torch.finfo((flow).dtype).eps + normalized_flow = flow / (max_norm + epsilon) + img = _normalized_flow_to_image(normalized_flow) + + if len(orig_shape) == 3: + img = img[0] # Remove batch dim + return img + + +@torch.no_grad() +def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: + + """ + Converts a batch of normalized flow to an RGB image. + + Args: + normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W) + Returns: + img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8. 
+ """ + + N, _, H, W = normalized_flow.shape + device = normalized_flow.device + flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device) + colorwheel = _make_colorwheel().to(device) # shape [55x3] + num_cols = colorwheel.shape[0] + norm = torch.sum(normalized_flow ** 2, dim=1).sqrt() + a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi + fk = (a + 1) / 2 * (num_cols - 1) + k0 = torch.floor(fk).to(torch.long) + k1 = k0 + 1 + k1[k1 == num_cols] = 0 + f = fk - k0 + + for c in range(colorwheel.shape[1]): + tmp = colorwheel[:, c] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1 - f) * col0 + f * col1 + col = 1 - norm * (1 - col) + flow_image[:, c, :, :] = torch.floor(255 * col) + return flow_image + + +def _make_colorwheel() -> torch.Tensor: + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf. + + Returns: + colorwheel (Tensor[55, 3]): Colorwheel Tensor. + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = torch.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY) + col = col + RY + # YG + colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG) + colorwheel[col : col + YG, 1] = 255 + col = col + YG + # GC + colorwheel[col : col + GC, 1] = 255 + colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC) + col = col + GC + # CB + colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB) + colorwheel[col : col + CB, 2] = 255 + col = col + CB + # BM + colorwheel[col : col + BM, 2] = 255 + colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM) + col = col + BM + # MR + colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR) + colorwheel[col : col + MR, 0] = 255 + return colorwheel + + +def _generate_color_palette(num_objects: int): + palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) + return [tuple((i * palette) % 255) for i in range(num_objects)] + + +def _log_api_usage_once(obj: Any) -> None: + + """ + Logs API usage(module and name) within an organization. + In a large ecosystem, it's often useful to track the PyTorch and + TorchVision APIs usage. This API provides the similar functionality to the + logging module in the Python stdlib. It can be used for debugging purpose + to log which methods are used and by default it is inactive, unless the user + manually subscribes a logger via the `SetAPIUsageLogger method `_. + Please note it is triggered only once for the same API call within a process. + It does not collect any data from open-source users since it is no-op by default. + For more information, please refer to + * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging; + * Logging policy: https://github.com/pytorch/vision/issues/5052; + + Args: + obj (class instance or method): an object to extract info from. 
+ """ + if not obj.__module__.startswith("torchvision"): + return + name = obj.__class__.__name__ + if isinstance(obj, FunctionType): + name = obj.__name__ + torch._C._log_api_usage_once(f"{obj.__module__}.{name}") diff --git a/networks/__init__.py b/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/networks/__pycache__/__init__.cpython-38.pyc b/networks/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2601138a7e8c8c742bd98367c8882c9f8b3bd9b Binary files /dev/null and b/networks/__pycache__/__init__.cpython-38.pyc differ diff --git a/networks/__pycache__/base_model.cpython-38.pyc b/networks/__pycache__/base_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ed98901e83b8822e9f39b68d428761a44c2b364 Binary files /dev/null and b/networks/__pycache__/base_model.cpython-38.pyc differ diff --git a/networks/__pycache__/trainer.cpython-38.pyc b/networks/__pycache__/trainer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47bd1b4978076e787fa427890635e4d91c7c601a Binary files /dev/null and b/networks/__pycache__/trainer.cpython-38.pyc differ diff --git a/networks/base_model.py b/networks/base_model.py new file mode 100755 index 0000000000000000000000000000000000000000..684bdd31004eb9d5664da1aba08dc3ba3b7c4d80 --- /dev/null +++ b/networks/base_model.py @@ -0,0 +1,58 @@ +import os +import torch +import torch.nn as nn +from torch.nn import init +from torch.optim import lr_scheduler + + +class BaseModel(nn.Module): + def __init__(self, opt): + super(BaseModel, self).__init__() + self.opt = opt + self.total_steps = 0 + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu') + + def save_networks(self, save_filename): + save_path = os.path.join(self.save_dir, save_filename) + + # serialize model and optimizer to dict + state_dict = { + 'model': self.model.state_dict(), + 'optimizer' : self.optimizer.state_dict(), + 'total_steps' : self.total_steps, + } + + torch.save(state_dict, save_path) + + + def eval(self): + self.model.eval() + + def test(self): + with torch.no_grad(): + self.forward() + + +def init_weights(net, init_type='normal', gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: + init.normal_(m.weight.data, 1.0, gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) diff --git a/networks/lpf.py b/networks/lpf.py new file mode 100755 index 0000000000000000000000000000000000000000..f64030bd9a73786f249e03b4d6ce02b32d5ecf92 --- /dev/null +++ b/networks/lpf.py @@ -0,0 +1,120 @@ +# Copyright (c) 2019, Adobe Inc. All rights reserved. 
+# +# This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike +# 4.0 International Public License. To view a copy of this license, visit +# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. + +import torch +import torch.nn.parallel +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from IPython import embed + +class Downsample(nn.Module): + def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0): + super(Downsample, self).__init__() + self.filt_size = filt_size + self.pad_off = pad_off + self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))] + self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes] + self.stride = stride + self.off = int((self.stride-1)/2.) + self.channels = channels + + # print('Filter size [%i]'%filt_size) + if(self.filt_size==1): + a = np.array([1.,]) + elif(self.filt_size==2): + a = np.array([1., 1.]) + elif(self.filt_size==3): + a = np.array([1., 2., 1.]) + elif(self.filt_size==4): + a = np.array([1., 3., 3., 1.]) + elif(self.filt_size==5): + a = np.array([1., 4., 6., 4., 1.]) + elif(self.filt_size==6): + a = np.array([1., 5., 10., 10., 5., 1.]) + elif(self.filt_size==7): + a = np.array([1., 6., 15., 20., 15., 6., 1.]) + + filt = torch.Tensor(a[:,None]*a[None,:]) + filt = filt/torch.sum(filt) + self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1))) + + self.pad = get_pad_layer(pad_type)(self.pad_sizes) + + def forward(self, inp): + if(self.filt_size==1): + if(self.pad_off==0): + return inp[:,:,::self.stride,::self.stride] + else: + return self.pad(inp)[:,:,::self.stride,::self.stride] + else: + return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) + +def get_pad_layer(pad_type): + if(pad_type in ['refl','reflect']): + PadLayer = nn.ReflectionPad2d + elif(pad_type in ['repl','replicate']): + PadLayer = nn.ReplicationPad2d + elif(pad_type=='zero'): + PadLayer = nn.ZeroPad2d + else: + print('Pad type [%s] not recognized'%pad_type) + return PadLayer + + +class Downsample1D(nn.Module): + def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0): + super(Downsample1D, self).__init__() + self.filt_size = filt_size + self.pad_off = pad_off + self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))] + self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes] + self.stride = stride + self.off = int((self.stride - 1) / 2.) 
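+        # Note: Downsample1D mirrors the 2-D Downsample above: blur with a binomial
+        # (Pascal's-triangle) kernel of length filt_size, then subsample by `stride`, following
+        # Zhang, "Making Convolutional Networks Shift-Invariant Again" (ICML 2019).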
+ self.channels = channels + + # print('Filter size [%i]' % filt_size) + if(self.filt_size == 1): + a = np.array([1., ]) + elif(self.filt_size == 2): + a = np.array([1., 1.]) + elif(self.filt_size == 3): + a = np.array([1., 2., 1.]) + elif(self.filt_size == 4): + a = np.array([1., 3., 3., 1.]) + elif(self.filt_size == 5): + a = np.array([1., 4., 6., 4., 1.]) + elif(self.filt_size == 6): + a = np.array([1., 5., 10., 10., 5., 1.]) + elif(self.filt_size == 7): + a = np.array([1., 6., 15., 20., 15., 6., 1.]) + + filt = torch.Tensor(a) + filt = filt / torch.sum(filt) + self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1))) + + self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes) + + def forward(self, inp): + if(self.filt_size == 1): + if(self.pad_off == 0): + return inp[:, :, ::self.stride] + else: + return self.pad(inp)[:, :, ::self.stride] + else: + return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) + + +def get_pad_layer_1d(pad_type): + if(pad_type in ['refl', 'reflect']): + PadLayer = nn.ReflectionPad1d + elif(pad_type in ['repl', 'replicate']): + PadLayer = nn.ReplicationPad1d + elif(pad_type == 'zero'): + PadLayer = nn.ZeroPad1d + else: + print('Pad type [%s] not recognized' % pad_type) + return PadLayer diff --git a/networks/resnet_lpf.py b/networks/resnet_lpf.py new file mode 100755 index 0000000000000000000000000000000000000000..f9e34254eadc00e701245d03ffd86c25c114ce3f --- /dev/null +++ b/networks/resnet_lpf.py @@ -0,0 +1,313 @@ +# This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models. +# Copyright (c) 2017 Torch Contributors. +# The Pytorch examples are available under the BSD 3-Clause License. +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2019 Adobe. All rights reserved. +# Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike +# 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit +# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. +# +# ========================================================================================== +# +# BSD-3 License +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from .lpf import * + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] + + +# model_urls = { +# 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', +# 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', +# 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', +# 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', +# 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +# } + + +def conv3x3(in_planes, out_planes, stride=1, groups=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, groups=groups, bias=False) + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1: + raise ValueError('BasicBlock only supports groups=1') + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + if(stride==1): + self.conv2 = conv3x3(planes,planes) + else: + self.conv2 = nn.Sequential(Downsample(filt_size=filter_size, stride=stride, channels=planes), + conv3x3(planes, planes),) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, planes) + self.bn1 = norm_layer(planes) + self.conv2 = conv3x3(planes, planes, groups) # stride moved + self.bn2 = norm_layer(planes) + if(stride==1): + self.conv3 = conv1x1(planes, planes * self.expansion) + else: + self.conv3 = nn.Sequential(Downsample(filt_size=filter_size, stride=stride, channels=planes), + conv1x1(planes, planes * self.expansion)) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + 
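+        # In this anti-aliased variant conv2 always runs at stride 1 ("stride moved" above);
+        # when stride != 1, the blur-pool Downsample wrapped inside self.conv3 performs the
+        # spatial reduction instead of a strided convolution.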
out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, norm_layer=None, filter_size=1, pool_only=True): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + planes = [int(width_per_group * groups * 2 ** i) for i in range(4)] + self.inplanes = planes[0] + + if(pool_only): + self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=2, padding=3, bias=False) + else: + self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=1, padding=3, bias=False) + self.bn1 = norm_layer(planes[0]) + self.relu = nn.ReLU(inplace=True) + + if(pool_only): + self.maxpool = nn.Sequential(*[nn.MaxPool2d(kernel_size=2, stride=1), + Downsample(filt_size=filter_size, stride=2, channels=planes[0])]) + else: + self.maxpool = nn.Sequential(*[Downsample(filt_size=filter_size, stride=2, channels=planes[0]), + nn.MaxPool2d(kernel_size=2, stride=1), + Downsample(filt_size=filter_size, stride=2, channels=planes[0])]) + + self.layer1 = self._make_layer(block, planes[0], layers[0], groups=groups, norm_layer=norm_layer) + self.layer2 = self._make_layer(block, planes[1], layers[1], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size) + self.layer3 = self._make_layer(block, planes[2], layers[2], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size) + self.layer4 = self._make_layer(block, planes[3], layers[3], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(planes[3] * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if(m.in_channels!=m.out_channels or m.out_channels!=m.groups or m.bias is not None): + # don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + else: + print('Not initializing') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, groups=1, norm_layer=None, filter_size=1): + if norm_layer is None: + norm_layer = nn.BatchNorm2d + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + # downsample = nn.Sequential( + # conv1x1(self.inplanes, planes * block.expansion, stride, filter_size=filter_size), + # norm_layer(planes * block.expansion), + # ) + + downsample = [Downsample(filt_size=filter_size, stride=stride, channels=self.inplanes),] if(stride !=1) else [] + downsample += [conv1x1(self.inplanes, planes * block.expansion, 1), + norm_layer(planes * block.expansion)] + # print(downsample) + downsample = nn.Sequential(*downsample) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, groups, norm_layer, filter_size=filter_size)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=groups, norm_layer=norm_layer, filter_size=filter_size)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +def resnet18(pretrained=False, filter_size=1, pool_only=True, **kwargs): + """Constructs a ResNet-18 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], filter_size=filter_size, pool_only=pool_only, **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) + return model + + +def resnet34(pretrained=False, filter_size=1, pool_only=True, **kwargs): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], filter_size=filter_size, pool_only=pool_only, **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) + return model + + +def resnet50(pretrained=False, filter_size=1, pool_only=True, **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], filter_size=filter_size, pool_only=pool_only, **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) + return model + + +def resnet101(pretrained=False, filter_size=1, pool_only=True, **kwargs): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], filter_size=filter_size, pool_only=pool_only, **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) + return model + + +def resnet152(pretrained=False, filter_size=1, pool_only=True, **kwargs): + """Constructs a ResNet-152 model. 
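+        Note: as with the other constructors in this file, ``pretrained=True`` depends on the
+        ``model_urls`` dict that is commented out near the top, so only ``pretrained=False``
+        (the default) works with the file as-is.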
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], filter_size=filter_size, pool_only=pool_only, **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) + return model + + +def resnext50_32x4d(pretrained=False, filter_size=1, pool_only=True, **kwargs): + model = ResNet(Bottleneck, [3, 4, 6, 3], groups=4, width_per_group=32, filter_size=filter_size, pool_only=pool_only, **kwargs) + # if pretrained: + # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) + return model + + +def resnext101_32x8d(pretrained=False, filter_size=1, pool_only=True, **kwargs): + model = ResNet(Bottleneck, [3, 4, 23, 3], groups=8, width_per_group=32, filter_size=filter_size, pool_only=pool_only, **kwargs) + # if pretrained: + # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) + return model diff --git a/networks/trainer.py b/networks/trainer.py new file mode 100755 index 0000000000000000000000000000000000000000..73d5bc0bcf5d48117e473900b38088c738f9c555 --- /dev/null +++ b/networks/trainer.py @@ -0,0 +1,74 @@ +import functools +import torch +import torch.nn as nn +from networks.base_model import BaseModel, init_weights +import sys +from models import get_model + +class Trainer(BaseModel): + def name(self): + return 'Trainer' + + def __init__(self, opt): + super(Trainer, self).__init__(opt) + self.opt = opt + self.model = get_model(opt.arch) + torch.nn.init.normal_(self.model.fc.weight.data, 0.0, opt.init_gain) + + if opt.fix_backbone: + params = [] + for name, p in self.model.named_parameters(): + if name=="fc.weight" or name=="fc.bias": + params.append(p) + else: + p.requires_grad = False + else: + print("Your backbone is not fixed. Are you sure you want to proceed? If this is a mistake, enable the --fix_backbone command during training and rerun") + import time + time.sleep(3) + params = self.model.parameters() + + + + if opt.optim == 'adam': + self.optimizer = torch.optim.AdamW(params, lr=opt.lr, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay) + elif opt.optim == 'sgd': + self.optimizer = torch.optim.SGD(params, lr=opt.lr, momentum=0.0, weight_decay=opt.weight_decay) + else: + raise ValueError("optim should be [adam, sgd]") + + self.loss_fn = nn.BCEWithLogitsLoss() + + self.model.to(opt.gpu_ids[0]) + + + def adjust_learning_rate(self, min_lr=1e-6): + for param_group in self.optimizer.param_groups: + param_group['lr'] /= 10. 
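+            # After the /10 decay above, report False as soon as the learning rate falls below
+            # min_lr; train.py uses this return value to decide between continuing training and
+            # stopping early.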
+ if param_group['lr'] < min_lr: + return False + return True + + + def set_input(self, input): + self.input = input[0].to(self.device) + self.label = input[1].to(self.device).float() + + + def forward(self): + self.output = self.model(self.input) + self.output = self.output.view(-1).unsqueeze(1) + + + def get_loss(self): + return self.loss_fn(self.output.squeeze(1), self.label) + + def optimize_parameters(self): + self.forward() + self.loss = self.loss_fn(self.output.squeeze(1), self.label) + self.optimizer.zero_grad() + self.loss.backward() + self.optimizer.step() + + + diff --git a/options/__init__.py b/options/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/options/__pycache__/__init__.cpython-38.pyc b/options/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5862e99fd70ebde7cb51bb416fcdd9f07874b8d4 Binary files /dev/null and b/options/__pycache__/__init__.cpython-38.pyc differ diff --git a/options/__pycache__/base_options.cpython-38.pyc b/options/__pycache__/base_options.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be8185b34a207c3f5f5c6dce0cbd1100c2eb0e4b Binary files /dev/null and b/options/__pycache__/base_options.cpython-38.pyc differ diff --git a/options/__pycache__/train_options.cpython-38.pyc b/options/__pycache__/train_options.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70480e767f5bf4a83b2445d02fdbb91083d3fcf9 Binary files /dev/null and b/options/__pycache__/train_options.cpython-38.pyc differ diff --git a/options/base_options.py b/options/base_options.py new file mode 100755 index 0000000000000000000000000000000000000000..69d35b17cf409eac0a55a74ed187f23028cb93c5 --- /dev/null +++ b/options/base_options.py @@ -0,0 +1,117 @@ +import argparse +import os +import util +import torch + + +class BaseOptions(): + def __init__(self): + self.initialized = False + + def initialize(self, parser): + parser.add_argument('--mode', default='binary') + parser.add_argument('--arch', type=str, default='res50', help='see my_models/__init__.py') + parser.add_argument('--fix_backbone', action='store_true') + + # data augmentation + parser.add_argument('--rz_interp', default='bilinear') + parser.add_argument('--blur_prob', type=float, default=0.5) + parser.add_argument('--blur_sig', default='0.0,3.0') + parser.add_argument('--jpg_prob', type=float, default=0.5) + parser.add_argument('--jpg_method', default='cv2,pil') + parser.add_argument('--jpg_qual', default='30,100') + + + parser.add_argument('--real_list_path', default=None, help='only used if data_mode==ours: path for the list of real images, which should contain train.pickle and val.pickle') + parser.add_argument('--fake_list_path', default=None, help='only used if data_mode==ours: path for the list of fake images, which should contain train.pickle and val.pickle') + parser.add_argument('--wang2020_data_path', default=None, help='only used if data_mode==wang2020 it should contain train and test folders') + parser.add_argument('--data_mode', default='ours', help='wang2020 or ours') + parser.add_argument('--data_label', default='train', help='label to decide whether train or validation dataset') + parser.add_argument('--weight_decay', type=float, default=0.0, help='loss weight for l2 reg') + + parser.add_argument('--class_bal', action='store_true') # what is this ? 
+ parser.add_argument('--batch_size', type=int, default=256, help='input batch size') + parser.add_argument('--loadSize', type=int, default=256, help='scale images to this size') + parser.add_argument('--cropSize', type=int, default=224, help='then crop to this size') + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') + parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') + parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--resize_or_crop', type=str, default='scale_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop|none]') + parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') + parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') + parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{loadSize}') + self.initialized = True + return parser + + def gather_options(self): + # initialize parser with basic options + if not self.initialized: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + opt, _ = parser.parse_known_args() + self.parser = parser + + return parser.parse_args() + + def print_options(self, opt): + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, 'opt.txt') + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + + def parse(self, print_options=True): + + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + if print_options: + self.print_options(opt) + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + opt.gpu_ids.append(id) + if len(opt.gpu_ids) > 0: + torch.cuda.set_device(opt.gpu_ids[0]) + + # additional + #opt.classes = opt.classes.split(',') + opt.rz_interp = opt.rz_interp.split(',') + opt.blur_sig = [float(s) for s in opt.blur_sig.split(',')] + opt.jpg_method = opt.jpg_method.split(',') + opt.jpg_qual = [int(s) for s in opt.jpg_qual.split(',')] + if len(opt.jpg_qual) == 2: + opt.jpg_qual = list(range(opt.jpg_qual[0], opt.jpg_qual[1] + 1)) + elif len(opt.jpg_qual) 
> 2: + raise ValueError("Shouldn't have more than 2 values for --jpg_qual.") + + self.opt = opt + return self.opt diff --git a/options/test_options.py b/options/test_options.py new file mode 100755 index 0000000000000000000000000000000000000000..f824c7aae81d325bf81e16c5a5e2c1931ccd5b34 --- /dev/null +++ b/options/test_options.py @@ -0,0 +1,13 @@ +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + parser.add_argument('--model_path') + parser.add_argument('--no_resize', action='store_true') + parser.add_argument('--no_crop', action='store_true') + parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') + + self.isTrain = False + return parser diff --git a/options/train_options.py b/options/train_options.py new file mode 100755 index 0000000000000000000000000000000000000000..e7f1e54cea63ef424ad84d1ccaa1538e74119466 --- /dev/null +++ b/options/train_options.py @@ -0,0 +1,22 @@ +from .base_options import BaseOptions + + +class TrainOptions(BaseOptions): + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + parser.add_argument('--earlystop_epoch', type=int, default=5) + parser.add_argument('--data_aug', action='store_true', help='if specified, perform additional data augmentation (photometric, blurring, jpegging)') + parser.add_argument('--optim', type=str, default='adam', help='optim to use [sgd, adam]') + parser.add_argument('--new_optim', action='store_true', help='new optimizer instead of loading the optim state') + parser.add_argument('--loss_freq', type=int, default=400, help='frequency of showing loss on tensorboard') + parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') + parser.add_argument('--last_epoch', type=int, default=-1, help='starting epoch count for scheduler intialization') + parser.add_argument('--train_split', type=str, default='train', help='train, val, test, etc') + parser.add_argument('--val_split', type=str, default='val', help='train, val, test, etc') + parser.add_argument('--niter', type=int, default=100, help='total epoches') + parser.add_argument('--beta1', type=float, default=0.9, help='momentum term of adam') + parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') + + self.isTrain = True + return parser diff --git a/pretrained_weights/fc_weights.pth b/pretrained_weights/fc_weights.pth new file mode 100644 index 0000000000000000000000000000000000000000..989708188fa14aa3f7ddfcddb579d7b9426d5e8e --- /dev/null +++ b/pretrained_weights/fc_weights.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:477100745713bcc957beb2b40859536859b6483fd6301b3b9293151b194c7847 +size 4083 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3db312263449129a06fef67d07987644cfe04c32 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,159 @@ +# This file may be used to create an environment using: +# $ conda create --name --file +# platform: linux-64 +_libgcc_mutex=0.1=main +_openmp_mutex=4.5=1_gnu +absl-py=1.1.0=pypi_0 +anyio=3.6.1=pypi_0 +argon2-cffi=21.3.0=pypi_0 +argon2-cffi-bindings=21.2.0=pypi_0 +asttokens=2.0.5=pypi_0 +attrs=21.4.0=pypi_0 +babel=2.10.3=pypi_0 +backcall=0.2.0=pypi_0 
+beautifulsoup4=4.11.1=pypi_0 +bleach=6.1.0=pypi_0 +brotlipy=0.7.0=py38h27cfd23_1003 +ca-certificates=2021.7.5=h06a4308_1 +cachetools=5.2.0=pypi_0 +certifi=2021.5.30=py38h06a4308_0 +cffi=1.14.6=py38h400218f_0 +chardet=4.0.0=py38h06a4308_1003 +conda=4.10.3=py38h06a4308_0 +conda-package-handling=1.7.3=py38h27cfd23_1 +cryptography=3.4.7=py38hd23ed53_0 +cycler=0.11.0=pypi_0 +debugpy=1.6.0=pypi_0 +decorator=5.1.1=pypi_0 +defusedxml=0.7.1=pypi_0 +docopt=0.6.2=pypi_0 +entrypoints=0.4=pypi_0 +executing=0.8.3=pypi_0 +fastjsonschema=2.15.3=pypi_0 +filelock=3.14.0=pypi_0 +fonttools=4.33.3=pypi_0 +ftfy=6.2.0=pypi_0 +gdown=5.2.0=pypi_0 +google-auth=2.8.0=pypi_0 +google-auth-oauthlib=0.4.6=pypi_0 +grpcio=1.46.3=pypi_0 +idna=2.10=pyhd3eb1b0_0 +imageio=2.34.1=pypi_0 +importlib-metadata=4.11.4=pypi_0 +importlib-resources=5.8.0=pypi_0 +ipykernel=6.15.0=pypi_0 +ipython=8.12.3=pypi_0 +ipython-genutils=0.2.0=pypi_0 +ipywidgets=7.7.0=pypi_0 +jedi=0.18.1=pypi_0 +jinja2=3.1.2=pypi_0 +joblib=1.4.2=pypi_0 +json5=0.9.8=pypi_0 +jsonschema=4.6.0=pypi_0 +jupyter-client=7.3.4=pypi_0 +jupyter-core=5.7.2=pypi_0 +jupyter-server=1.17.1=pypi_0 +jupyterlab=3.4.3=pypi_0 +jupyterlab-language-pack-zh-cn=3.4.post1=pypi_0 +jupyterlab-pygments=0.2.2=pypi_0 +jupyterlab-server=2.14.0=pypi_0 +jupyterlab-widgets=1.1.0=pypi_0 +kiwisolver=1.4.3=pypi_0 +lazy-loader=0.4=pypi_0 +ld_impl_linux-64=2.35.1=h7274673_9 +libffi=3.3=he6710b0_2 +libgcc-ng=9.3.0=h5101ec6_17 +libgomp=9.3.0=h5101ec6_17 +libstdcxx-ng=9.3.0=hd4cf53a_17 +markdown=3.3.7=pypi_0 +markupsafe=2.1.1=pypi_0 +matplotlib=3.5.2=pypi_0 +matplotlib-inline=0.1.3=pypi_0 +mistune=3.0.2=pypi_0 +nbclassic=0.3.7=pypi_0 +nbclient=0.6.4=pypi_0 +nbconvert=7.16.4=pypi_0 +nbformat=5.10.4=pypi_0 +ncurses=6.2=he6710b0_1 +nest-asyncio=1.5.5=pypi_0 +networkx=3.1=pypi_0 +notebook=6.4.12=pypi_0 +notebook-shim=0.1.0=pypi_0 +numpy=1.22.4=pypi_0 +oauthlib=3.2.0=pypi_0 +opencv-python=4.10.0.82=pypi_0 +openssl=1.1.1k=h27cfd23_0 +packaging=21.3=pypi_0 +pandocfilters=1.5.0=pypi_0 +parso=0.8.3=pypi_0 +pexpect=4.8.0=pypi_0 +pickleshare=0.7.5=pypi_0 +pillow=9.1.1=pypi_0 +pip=21.1.3=py38h06a4308_0 +pipreqs=0.5.0=pypi_0 +platformdirs=4.2.2=pypi_0 +prometheus-client=0.14.1=pypi_0 +prompt-toolkit=3.0.47=pypi_0 +protobuf=5.27.1=pypi_0 +psutil=5.9.1=pypi_0 +ptyprocess=0.7.0=pypi_0 +pure-eval=0.2.2=pypi_0 +pyasn1=0.4.8=pypi_0 +pyasn1-modules=0.2.8=pypi_0 +pycosat=0.6.3=py38h7b6447c_1 +pycparser=2.20=py_2 +pygments=2.12.0=pypi_0 +pyopenssl=20.0.1=pyhd3eb1b0_1 +pyparsing=3.0.9=pypi_0 +pyrsistent=0.18.1=pypi_0 +pysocks=1.7.1=py38h06a4308_0 +python=3.8.10=h12debd9_8 +python-dateutil=2.8.2=pypi_0 +pytz=2022.1=pypi_0 +pywavelets=1.4.1=pypi_0 +pyzmq=23.2.0=pypi_0 +readline=8.1=h27cfd23_0 +regex=2024.5.15=pypi_0 +requests=2.25.1=pyhd3eb1b0_0 +requests-oauthlib=1.3.1=pypi_0 +rsa=4.8=pypi_0 +ruamel_yaml=0.15.100=py38h27cfd23_0 +scikit-image=0.21.0=pypi_0 +scikit-learn=1.3.2=pypi_0 +scipy=1.10.1=pypi_0 +send2trash=1.8.0=pypi_0 +setuptools=52.0.0=py38h06a4308_0 +six=1.16.0=pyhd3eb1b0_0 +sklearn=0.0=pypi_0 +sniffio=1.2.0=pypi_0 +soupsieve=2.3.2.post1=pypi_0 +sqlite=3.36.0=hc218d9a_0 +stack-data=0.3.0=pypi_0 +supervisor=4.2.4=pypi_0 +tensorboard=2.9.1=pypi_0 +tensorboard-data-server=0.6.1=pypi_0 +tensorboard-plugin-wit=1.8.1=pypi_0 +tensorboardx=2.6.2.2=pypi_0 +terminado=0.15.0=pypi_0 +threadpoolctl=3.5.0=pypi_0 +tifffile=2023.7.10=pypi_0 +tinycss2=1.1.1=pypi_0 +tk=8.6.10=hbc83047_0 +torch=1.11.0+cu113=pypi_0 +torchvision=0.12.0+cu113=pypi_0 +tornado=6.1=pypi_0 +tqdm=4.61.2=pyhd3eb1b0_1 +traitlets=5.3.0=pypi_0 
+typing-extensions=4.2.0=pypi_0 +urllib3=1.26.6=pyhd3eb1b0_1 +wcwidth=0.2.13=pypi_0 +webencodings=0.5.1=pypi_0 +websocket-client=1.3.3=pypi_0 +werkzeug=2.1.2=pypi_0 +wheel=0.36.2=pyhd3eb1b0_0 +widgetsnbextension=3.6.0=pypi_0 +xz=5.2.5=h7b6447c_0 +yaml=0.2.5=h7b6447c_0 +yarg=0.1.9=pypi_0 +zipp=3.8.0=pypi_0 +zlib=1.2.11=h7b6447c_3 diff --git a/resources/teaser.png b/resources/teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..722e8c63aed79b79c38d6e20fa80c5539148297c Binary files /dev/null and b/resources/teaser.png differ diff --git a/test.sh b/test.sh new file mode 100644 index 0000000000000000000000000000000000000000..983d05b7c4cf0e8b0a977b835849a102fa8b75fc --- /dev/null +++ b/test.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=0 python3 validate.py --arch=CLIP:ViT-L/14 --ckpt=pretrained_weights/fc_weights.pth --result_folder=clip_vitl14 diff --git a/train.py b/train.py new file mode 100755 index 0000000000000000000000000000000000000000..c24917730db61c90f62015ca68572b83219e2050 --- /dev/null +++ b/train.py @@ -0,0 +1,85 @@ +import os +import time +from tensorboardX import SummaryWriter + +from validate import validate +from data import create_dataloader +from earlystop import EarlyStopping +from networks.trainer import Trainer +from options.train_options import TrainOptions + + +"""Currently assumes jpg_prob, blur_prob 0 or 1""" +def get_val_opt(): + val_opt = TrainOptions().parse(print_options=False) + val_opt.isTrain = False + val_opt.no_resize = False + val_opt.no_crop = False + val_opt.serial_batches = True + val_opt.data_label = 'val' + val_opt.jpg_method = ['pil'] + if len(val_opt.blur_sig) == 2: + b_sig = val_opt.blur_sig + val_opt.blur_sig = [(b_sig[0] + b_sig[1]) / 2] + if len(val_opt.jpg_qual) != 1: + j_qual = val_opt.jpg_qual + val_opt.jpg_qual = [int((j_qual[0] + j_qual[-1]) / 2)] + + return val_opt + + + +if __name__ == '__main__': + opt = TrainOptions().parse() + val_opt = get_val_opt() + + model = Trainer(opt) + + data_loader = create_dataloader(opt) + val_loader = create_dataloader(val_opt) + + train_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "train")) + val_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "val")) + + early_stopping = EarlyStopping(patience=opt.earlystop_epoch, delta=-0.001, verbose=True) + start_time = time.time() + print ("Length of data loader: %d" %(len(data_loader))) + for epoch in range(opt.niter): + + for i, data in enumerate(data_loader): + model.total_steps += 1 + + model.set_input(data) + model.optimize_parameters() + + if model.total_steps % opt.loss_freq == 0: + print("Train loss: {} at step: {}".format(model.loss, model.total_steps)) + train_writer.add_scalar('loss', model.loss, model.total_steps) + print("Iter time: ", ((time.time()-start_time)/model.total_steps) ) + + if model.total_steps in [10,30,50,100,1000,5000,10000] and False: # save models at these iters + model.save_networks('model_iters_%s.pth' % model.total_steps) + + if epoch % opt.save_epoch_freq == 0: + print('saving the model at the end of epoch %d' % (epoch)) + model.save_networks( 'model_epoch_best.pth' ) + model.save_networks( 'model_epoch_%s.pth' % epoch ) + + # Validation + model.eval() + ap, r_acc, f_acc, acc = validate(model.model, val_loader) + val_writer.add_scalar('accuracy', acc, model.total_steps) + val_writer.add_scalar('ap', ap, model.total_steps) + print("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap)) + + early_stopping(acc, model) + if early_stopping.early_stop: + cont_train = 
model.adjust_learning_rate() + if cont_train: + print("Learning rate dropped by 10, continue training...") + early_stopping = EarlyStopping(patience=opt.earlystop_epoch, delta=-0.002, verbose=True) + else: + print("Early stopping.") + break + model.train() + diff --git a/util.py b/util.py new file mode 100644 index 0000000000000000000000000000000000000000..53f23f80e7f9a37c9be67cbd81c7c67f8594706a --- /dev/null +++ b/util.py @@ -0,0 +1,21 @@ +import os +import torch + + +def mkdirs(paths): + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def unnormalize(tens, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): + # assume tensor of shape NxCxHxW + return tens * torch.Tensor(std)[None, :, None, None] + torch.Tensor( + mean)[None, :, None, None] \ No newline at end of file diff --git a/validate.py b/validate.py new file mode 100644 index 0000000000000000000000000000000000000000..4887014457e584b5801e868a229db98bbd70117c --- /dev/null +++ b/validate.py @@ -0,0 +1,312 @@ +import argparse +from ast import arg +import os +import csv +import torch +import torchvision.transforms as transforms +import torch.utils.data +import numpy as np +from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score +from torch.utils.data import Dataset +import sys +from models import get_model +from PIL import Image +import pickle +from tqdm import tqdm +from io import BytesIO +from copy import deepcopy +from dataset_paths import DATASET_PATHS +import random +import shutil +from scipy.ndimage.filters import gaussian_filter + +SEED = 0 +def set_seed(): + torch.manual_seed(SEED) + torch.cuda.manual_seed(SEED) + np.random.seed(SEED) + random.seed(SEED) + + +MEAN = { + "imagenet":[0.485, 0.456, 0.406], + "clip":[0.48145466, 0.4578275, 0.40821073] +} + +STD = { + "imagenet":[0.229, 0.224, 0.225], + "clip":[0.26862954, 0.26130258, 0.27577711] +} + + + + + +def find_best_threshold(y_true, y_pred): + "We assume first half is real 0, and the second half is fake 1" + + N = y_true.shape[0] + + if y_pred[0:N//2].max() <= y_pred[N//2:N].min(): # perfectly separable case + return (y_pred[0:N//2].max() + y_pred[N//2:N].min()) / 2 + + best_acc = 0 + best_thres = 0 + for thres in y_pred: + temp = deepcopy(y_pred) + temp[temp>=thres] = 1 + temp[temp= best_acc: + best_thres = thres + best_acc = acc + + return best_thres + + + +def png2jpg(img, quality): + out = BytesIO() + img.save(out, format='jpeg', quality=quality) # ranging from 0-95, 75 is default + img = Image.open(out) + # load from memory before ByteIO closes + img = np.array(img) + out.close() + return Image.fromarray(img) + + +def gaussian_blur(img, sigma): + img = np.array(img) + + gaussian_filter(img[:,:,0], output=img[:,:,0], sigma=sigma) + gaussian_filter(img[:,:,1], output=img[:,:,1], sigma=sigma) + gaussian_filter(img[:,:,2], output=img[:,:,2], sigma=sigma) + + return Image.fromarray(img) + + + +def calculate_acc(y_true, y_pred, thres): + r_acc = accuracy_score(y_true[y_true==0], y_pred[y_true==0] > thres) + f_acc = accuracy_score(y_true[y_true==1], y_pred[y_true==1] > thres) + acc = accuracy_score(y_true, y_pred > thres) + return r_acc, f_acc, acc + + +def validate(model, loader, find_thres=False): + + with torch.no_grad(): + y_true, y_pred = [], [] + print ("Length of dataset: %d" %(len(loader))) + for img, label in loader: + in_tens = img.cuda() + + 
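+            # The detector produces one logit per image; sigmoid maps it to the predicted
+            # probability of class 1 ("fake").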
y_pred.extend(model(in_tens).sigmoid().flatten().tolist()) + y_true.extend(label.flatten().tolist()) + + y_true, y_pred = np.array(y_true), np.array(y_pred) + + # ================== save this if you want to plot the curves =========== # + # torch.save( torch.stack( [torch.tensor(y_true), torch.tensor(y_pred)] ), 'baseline_predication_for_pr_roc_curve.pth' ) + # exit() + # =================================================================== # + + # Get AP + ap = average_precision_score(y_true, y_pred) + + # Acc based on 0.5 + r_acc0, f_acc0, acc0 = calculate_acc(y_true, y_pred, 0.5) + if not find_thres: + return ap, r_acc0, f_acc0, acc0 + + + # Acc based on the best thres + best_thres = find_best_threshold(y_true, y_pred) + r_acc1, f_acc1, acc1 = calculate_acc(y_true, y_pred, best_thres) + + return ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres + + + + + + +# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = # + + + + +def recursively_read(rootdir, must_contain, exts=["png", "jpg", "JPEG", "jpeg", "bmp"]): + out = [] + for r, d, f in os.walk(rootdir): + for file in f: + if (file.split('.')[1] in exts) and (must_contain in os.path.join(r, file)): + out.append(os.path.join(r, file)) + return out + + +def get_list(path, must_contain=''): + if ".pickle" in path: + with open(path, 'rb') as f: + image_list = pickle.load(f) + image_list = [ item for item in image_list if must_contain in item ] + else: + image_list = recursively_read(path, must_contain) + return image_list + + + + + +class RealFakeDataset(Dataset): + def __init__(self, real_path, + fake_path, + data_mode, + max_sample, + arch, + jpeg_quality=None, + gaussian_sigma=None): + + assert data_mode in ["wang2020", "ours"] + self.jpeg_quality = jpeg_quality + self.gaussian_sigma = gaussian_sigma + + # = = = = = = data path = = = = = = = = = # + if type(real_path) == str and type(fake_path) == str: + real_list, fake_list = self.read_path(real_path, fake_path, data_mode, max_sample) + else: + real_list = [] + fake_list = [] + for real_p, fake_p in zip(real_path, fake_path): + real_l, fake_l = self.read_path(real_p, fake_p, data_mode, max_sample) + real_list += real_l + fake_list += fake_l + + self.total_list = real_list + fake_list + + + # = = = = = = label = = = = = = = = = # + + self.labels_dict = {} + for i in real_list: + self.labels_dict[i] = 0 + for i in fake_list: + self.labels_dict[i] = 1 + + stat_from = "imagenet" if arch.lower().startswith("imagenet") else "clip" + self.transform = transforms.Compose([ + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( mean=MEAN[stat_from], std=STD[stat_from] ), + ]) + + + def read_path(self, real_path, fake_path, data_mode, max_sample): + + if data_mode == 'wang2020': + real_list = get_list(real_path, must_contain='0_real') + fake_list = get_list(fake_path, must_contain='1_fake') + else: + real_list = get_list(real_path) + fake_list = get_list(fake_path) + + + if max_sample is not None: + if (max_sample > len(real_list)) or (max_sample > len(fake_list)): + max_sample = 100 + print("not enough images, max_sample falling to 100") + random.shuffle(real_list) + random.shuffle(fake_list) + real_list = real_list[0:max_sample] + fake_list = fake_list[0:max_sample] + + assert len(real_list) == len(fake_list) + + return real_list, fake_list + + + + def __len__(self): + return len(self.total_list) + + def __getitem__(self, idx): + + img_path = self.total_list[idx] + + label = self.labels_dict[img_path] + img = Image.open(img_path).convert("RGB") + 
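+        # Optional robustness corruptions: blur and/or JPEG re-compress the image before the
+        # CLIP preprocessing transform, only when --gaussian_sigma / --jpeg_quality are set.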
+ if self.gaussian_sigma is not None: + img = gaussian_blur(img, self.gaussian_sigma) + if self.jpeg_quality is not None: + img = png2jpg(img, self.jpeg_quality) + + img = self.transform(img) + return img, label + + + + + +if __name__ == '__main__': + + + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--real_path', type=str, default=None, help='dir name or a pickle') + parser.add_argument('--fake_path', type=str, default=None, help='dir name or a pickle') + parser.add_argument('--data_mode', type=str, default=None, help='wang2020 or ours') + parser.add_argument('--max_sample', type=int, default=1000, help='only check this number of images for both fake/real') + + parser.add_argument('--arch', type=str, default='res50') + parser.add_argument('--ckpt', type=str, default='./pretrained_weights/fc_weights.pth') + + parser.add_argument('--result_folder', type=str, default='result', help='') + parser.add_argument('--batch_size', type=int, default=128) + + parser.add_argument('--jpeg_quality', type=int, default=None, help="100, 90, 80, ... 30. Used to test robustness of our model. Not apply if None") + parser.add_argument('--gaussian_sigma', type=int, default=None, help="0,1,2,3,4. Used to test robustness of our model. Not apply if None") + + + opt = parser.parse_args() + + + if os.path.exists(opt.result_folder): + shutil.rmtree(opt.result_folder) + os.makedirs(opt.result_folder) + + model = get_model(opt.arch) + state_dict = torch.load(opt.ckpt, map_location='cpu') + model.fc.load_state_dict(state_dict) + print ("Model loaded..") + model.eval() + model.cuda() + + if (opt.real_path == None) or (opt.fake_path == None) or (opt.data_mode == None): + dataset_paths = DATASET_PATHS + else: + dataset_paths = [ dict(real_path=opt.real_path, fake_path=opt.fake_path, data_mode=opt.data_mode) ] + + + + for dataset_path in (dataset_paths): + set_seed() + + dataset = RealFakeDataset( dataset_path['real_path'], + dataset_path['fake_path'], + dataset_path['data_mode'], + opt.max_sample, + opt.arch, + jpeg_quality=opt.jpeg_quality, + gaussian_sigma=opt.gaussian_sigma, + ) + + loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=4) + ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres = validate(model, loader, find_thres=True) + + with open( os.path.join(opt.result_folder,'ap.txt'), 'a') as f: + f.write(dataset_path['key']+': ' + str(round(ap*100, 2))+'\n' ) + + with open( os.path.join(opt.result_folder,'acc0.txt'), 'a') as f: + f.write(dataset_path['key']+': ' + str(round(r_acc0*100, 2))+' '+str(round(f_acc0*100, 2))+' '+str(round(acc0*100, 2))+'\n' ) +