# torch and transformers are assumed to be preinstalled (e.g., on Google Colab);
# add them to this install line if they are missing from your environment.
!pip install gradio

import gradio as gr
from transformers import AutoImageProcessor, SiglipForImageClassification
from torch.optim import AdamW
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import os

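# Load the pretrained SigLIP-based deepfake classifier and its image processor.
# The model is kept in training mode so it can later be updated from user feedback.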
model_name = "prithivMLmods/deepfake-detector-model-v1"
processor = AutoImageProcessor.from_pretrained(model_name)
model = SiglipForImageClassification.from_pretrained(model_name)
model.train()

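# Run on a GPU when available, otherwise fall back to the CPU.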
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

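# Label mappings used by the detector: class 0 is FAKE, class 1 is REAL.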
id2label = {0: "FAKE", 1: "REAL"}
label2id = {"FAKE": 0, "REAL": 1}

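# Optimizer for the feedback fine-tuning steps; the small learning rate keeps
# single-image updates from shifting the model too aggressively.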
optimizer = AdamW(model.parameters(), lr=5e-6)

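# Minimal Dataset wrapping a single user-corrected (image, label) pair so it can
# be fed through a standard DataLoader.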
class SingleImageDataset(Dataset):
    def __init__(self, image, label):
        self.image = image
        self.label = label

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        inputs = processor(images=self.image, return_tensors="pt")
        # Drop the batch dimension added by the processor; the DataLoader re-adds it.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs['labels'] = torch.tensor(self.label)
        return inputs

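# Fine-tune on the single corrected example, then save the updated weights and
# processor configuration to disk.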
def fine_tune(image, correct_label):
    dataset = SingleImageDataset(image, correct_label)
    dataloader = DataLoader(dataset, batch_size=1)

    model.train()
    for epoch in range(1):
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    save_path = "./fine_tuned_model"
    os.makedirs(save_path, exist_ok=True)
    model.save_pretrained(save_path)
    processor.save_pretrained(save_path)

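# Classify a single image and map the top logit back to its REAL/FAKE label.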
def predict(image):
    model.eval()
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    pred_class = logits.argmax(-1).item()
    return id2label[pred_class]

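# Gradio callback: report the prediction and, if the user marks it as wrong,
# fine-tune the model on the label they provide.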
def inference(image, feedback, correct_label_text):
    if image is None:
        return "Please upload an image.", None

    prediction = predict(image)
    message = f"Prediction: {prediction}"

    if feedback == "Wrong":
        corrected = (correct_label_text or "").strip().upper()
        if corrected in label2id:
            fine_tune(image, label2id[corrected])
            message += f" | Model fine-tuned with correct label: {corrected}"
        else:
            message += " | Please enter a valid correct label (REAL or FAKE)."

    return message, image

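# Assemble the Gradio interface: an image upload, a Correct/Wrong radio button,
# and a textbox for the corrected label when the prediction is wrong.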
title = "Deepfake Detector with Interactive Feedback and Fine-tuning"

iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Radio(["Correct", "Wrong"], label="Is the prediction correct?", value="Correct"),
        gr.Textbox(label="If Wrong, enter correct label (REAL or FAKE)", lines=1, placeholder="REAL or FAKE"),
    ],
    outputs=[
        gr.Textbox(label="Output"),
        gr.Image(type="pil", label="Uploaded Image"),
    ],
    title=title,
    live=False,
    allow_flagging="never",
)

iface.launch()