Spaces:
Sleeping
Sleeping
| import argparse | |
| import torch.nn | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| from .CNN.networks.resnet import resnet50 | |
def predict_cnn(image, model_path, crop=None):
    """Return the detector's probability that *image* is synthetic.

    Builds ``resnet50(num_classes=1)`` (from ``.CNN.networks.resnet``), loads
    the state dict stored under the ``"model"`` key of the checkpoint at
    *model_path*, preprocesses *image*, and returns the sigmoid of the single
    output logit.

    Args:
        image: A PIL image. It is converted to RGB before preprocessing.
        model_path: Path to a checkpoint file whose ``"model"`` entry holds
            the network's state dict.
        crop: Optional size passed to ``transforms.CenterCrop`` (an int, or
            any size value ``CenterCrop`` accepts). ``None``, the default,
            skips cropping.

    Returns:
        A Python float: the sigmoid of the model output, i.e. in [0, 1].
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = resnet50(num_classes=1)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict["model"])
    model.to(device)
    model.eval()

    # Always bind the crop stage list (empty when no crop is requested) so
    # the Compose below is valid for the default crop=None.
    trans_init = []
    if crop is not None:
        trans_init.append(transforms.CenterCrop(crop))
        print(f"Cropping to [{crop}]")
    trans = transforms.Compose(
        trans_init
        + [
            transforms.ToTensor(),
            # Standard ImageNet per-channel RGB mean / std.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ],
    )

    # Add a batch dimension and move the input to the same device as the
    # model; a CPU tensor fed to a CUDA model raises a device-mismatch error.
    in_tens = trans(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        prob = model(in_tens).sigmoid().item()
    return prob
# Command-line entry point: classify one image file and print the
# probability (as a percentage) that it is synthetic.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # NOTE(review): this default reads like a directory name, but Image.open
    # below needs a single image file -- confirm the intended default.
    parser.add_argument(
        "-f",
        "--file",
        default="examples_realfakedir",
        help="path of the image to classify",
    )
    parser.add_argument(
        "-m",
        "--model_path",
        type=str,
        default="weights/blur_jpg_prob0.5.pth",
        help="path of the detector checkpoint",
    )
    parser.add_argument(
        "-c",
        "--crop",
        type=int,
        default=None,
        help="by default, do not crop. specify crop size",
    )
    opt = parser.parse_args()

    # Use a context manager so the underlying file handle is closed as soon
    # as inference has consumed the image.
    with Image.open(opt.file) as img:
        prob = predict_cnn(img, opt.model_path, crop=opt.crop)
    print(f"probability of being synthetic: {prob * 100:.2f}%")