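"""Gradio demo for an AI-generated image detector: a CLIP ViT-L/14 backbone
with a trained linear classification head estimates how likely an uploaded
image is AI-generated."""
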
import gradio as gr
import torch
import torchvision.transforms as transforms

from models import get_model

MEAN = {
    "imagenet": [0.485, 0.456, 0.406],
    "clip": [0.48145466, 0.4578275, 0.40821073],
}

STD = {
    "imagenet": [0.229, 0.224, 0.225],
    "clip": [0.26862954, 0.26130258, 0.27577711],
}
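# "imagenet" above holds the standard torchvision ImageNet statistics; "clip"
# holds the per-channel values used by OpenAI's CLIP preprocessing. Only the
# CLIP statistics are used by the transform below.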


def detect_one_image(model, image):
    """Return the sigmoid probability that a single image is AI-generated."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.CenterCrop(224),
        transforms.Normalize(mean=MEAN["clip"], std=STD["clip"]),
    ])
    # ToTensor yields a [C, H, W] tensor; the model expects a batched input,
    # so add a leading batch dimension and move it to the model's device.
    img = transform(image).unsqueeze(0)
    img = img.to(next(model.parameters()).device)

    # No gradients are needed for inference.
    with torch.no_grad():
        detection_output = model(img)

    return torch.sigmoid(detection_output)

# Load the CLIP:ViT-L/14 backbone once at startup and attach the trained
# linear-head weights, rather than reloading the model on every request.
model = get_model("CLIP:ViT-L/14")
state_dict = torch.load("./pretrained_weights/fc_weights.pth", map_location="cpu")
model.fc.load_state_dict(state_dict)
model.eval()
if torch.cuda.is_available():
    model.cuda()


def detect(image):
    output_tensor = detect_one_image(model, image)
    ai_likelihood = (100 * output_tensor).item()
    return f"The image is {ai_likelihood:.1f}% likely to be AI-generated."

demo = gr.Interface(
    fn=detect,
    inputs=["image"],
    outputs=["text"],
)
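# gr.Interface wires detect() to an image-upload input and a text output;
# launch() then serves the demo locally (http://127.0.0.1:7860 by default).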

demo.launch()