# Install dependencies (run as a notebook cell)
!pip install gradio transformers torch
import gradio as gr
from transformers import AutoImageProcessor, SiglipForImageClassification
from torch.optim import AdamW
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import os

# Load model and processor
model_name = "prithivMLmods/deepfake-detector-model-v1"
processor = AutoImageProcessor.from_pretrained(model_name)
model = SiglipForImageClassification.from_pretrained(model_name)
model.train()

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Labels mapping
id2label = {0: "FAKE", 1: "REAL"}
label2id = {"FAKE": 0, "REAL": 1}

# Optimizer for fine-tuning
optimizer = AdamW(model.parameters(), lr=5e-6)

# Dataset class for single example fine-tuning
class SingleImageDataset(Dataset):
    def __init__(self, image, label):
        self.image = image
        self.label = label

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        # The processor returns a batch of size 1; squeeze so the DataLoader can re-batch
        inputs = processor(images=self.image, return_tensors="pt")
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs["labels"] = torch.tensor(self.label)
        return inputs

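# Take one gradient step on the single corrected example, then save the updated model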
def fine_tune(image, correct_label):
    dataset = SingleImageDataset(image, correct_label)
    dataloader = DataLoader(dataset, batch_size=1)

    model.train()
    for epoch in range(1):  # just 1 epoch for fast feedback
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    # Persist the updated weights and processor config locally
    save_path = "./fine_tuned_model"
    os.makedirs(save_path, exist_ok=True)
    model.save_pretrained(save_path)
    processor.save_pretrained(save_path)

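# Run a single forward pass and return the predicted label string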
def predict(image):
    model.eval()
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        pred_class = logits.argmax(-1).item()
    return id2label[pred_class]

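# Gradio callback: run prediction, and fine-tune when the user marks the prediction as wrong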
def inference(image, feedback, correct_label_text):
    if image is None:
        return "Please upload an image.", None

    prediction = predict(image)
    message = f"Prediction: {prediction}"

    if feedback == "Wrong":
        corrected = (correct_label_text or "").strip().upper()
        if corrected in label2id:
            fine_tune(image, label2id[corrected])
            message += f" | Model fine-tuned with correct label: {corrected}"
        else:
            message += " | Please enter a valid correct label (REAL or FAKE)."

    return message, image

# Gradio UI setup
title = "Deepfake Detector with Interactive Feedback and Fine-tuning"

iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Radio(["Correct", "Wrong"], label="Is the prediction correct?", value="Correct"),
        gr.Textbox(label="If Wrong, enter correct label (REAL or FAKE)", lines=1, placeholder="REAL or FAKE")
    ],
    outputs=[
        gr.Textbox(label="Output"),
        gr.Image(type="pil", label="Uploaded Image")
    ],
    title=title,
    live=False,
    allow_flagging="never"
)

iface.launch()