File size: 3,332 Bytes
5632cc9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
!pip install gradio
import gradio as gr
from transformers import AutoImageProcessor, SiglipForImageClassification
from torch.optim import AdamW
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import os
# --- Model bootstrap ---------------------------------------------------------
# Checkpoint identifier handed to both from_pretrained() calls below.
model_name = "prithivMLmods/deepfake-detector-model-v1"

# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The processor turns images into model inputs; the model is the classifier.
processor = AutoImageProcessor.from_pretrained(model_name)
model = SiglipForImageClassification.from_pretrained(model_name)
# Start in training mode; predict() and fine_tune() set the mode they need.
model.train()
model.to(device)

# Class-index <-> label-name lookup tables.
# NOTE(review): assumed to match the checkpoint's own label order - confirm.
id2label = {0: "FAKE", 1: "REAL"}
label2id = {name: index for index, name in id2label.items()}

# A single optimizer instance shared by every fine-tuning call.
optimizer = AdamW(model.parameters(), lr=5e-6)
# Dataset class for single example fine-tuning
class SingleImageDataset(Dataset):
    """Map-style dataset that holds exactly one (image, label) pair.

    The image is run through the module-level ``processor`` each time the
    item is fetched; nothing is cached.

    Args:
        image: The image passed to ``processor(images=...)``.
        label: The target class id; wrapped with ``torch.tensor`` on access.
    """

    def __init__(self, image, label):
        self.image = image
        self.label = label

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        """Return the processed sample as a dict of tensors plus ``labels``.

        Raises:
            IndexError: If ``idx`` does not address the single stored item.
        """
        # Keep indexing consistent with __len__() == 1. Without this check
        # any index returns the sample, and plain ``for x in dataset``
        # (which stops on IndexError) would never terminate.
        if idx not in (0, -1):
            raise IndexError(
                f"SingleImageDataset index out of range: {idx!r}"
            )
        inputs = processor(images=self.image, return_tensors="pt")
        # squeeze(0) removes the leading axis when it has size 1, so the
        # DataLoader can add its own batch axis when collating.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs['labels'] = torch.tensor(self.label)
        return inputs
def fine_tune(image, correct_label):
    """Take one optimizer step on a single labelled image, then save to disk.

    The image/label pair is wrapped in ``SingleImageDataset`` and fed through
    a ``DataLoader`` with ``batch_size=1``. After the update, the module-level
    ``model`` and ``processor`` are written to ``./fine_tuned_model``.

    Args:
        image: Image forwarded to ``SingleImageDataset``.
        correct_label: Class id used as the training target.

    Returns:
        None.
    """
    loader = DataLoader(SingleImageDataset(image, correct_label), batch_size=1)

    model.train()
    for _ in range(1):  # just 1 epoch for fast feedback
        for sample in loader:
            # Move every tensor in the batch onto the model's device.
            on_device = {name: tensor.to(device) for name, tensor in sample.items()}
            loss = model(**on_device).loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Save the updated model locally
    out_dir = "./fine_tuned_model"
    os.makedirs(out_dir, exist_ok=True)
    model.save_pretrained(out_dir)
    processor.save_pretrained(out_dir)
    return None
def predict(image):
    """Classify one image and return its label name.

    Puts the module-level ``model`` into eval mode, runs a forward pass with
    gradients disabled, and maps the arg-max class index through ``id2label``.

    Args:
        image: Image forwarded to ``processor(images=...)``.

    Returns:
        The ``id2label`` entry for the highest-scoring class.
    """
    model.eval()
    encoded = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**encoded).logits
    top_index = logits.argmax(-1).item()
    return id2label[top_index]
def inference(image, feedback, correct_label_text):
    """Gradio callback: predict, and optionally fine-tune on user feedback.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.
        feedback: The radio choice; fine-tuning only happens for ``"Wrong"``.
        correct_label_text: Free-text label typed by the user. Matching is
            case-insensitive and ignores surrounding whitespace; ``None`` is
            treated like an empty string.

    Returns:
        A ``(message, image)`` tuple. ``image`` is the input echoed back, or
        ``None`` when no image was supplied.
    """
    if image is None:
        return "Please upload an image.", None

    prediction = predict(image)
    message = f"Prediction: {prediction}"
    if feedback != "Wrong":
        return message, image

    # Normalise the typed label: a missing value must not crash on .upper(),
    # and stray whitespace (e.g. " real ") should still match a known label.
    label_name = (correct_label_text or "").strip().upper()
    if label_name in label2id:
        fine_tune(image, label2id[label_name])
        message += f" | Model fine-tuned with correct label: {label_name}"
    else:
        message += " | Please enter a valid correct label (REAL or FAKE)."
    return message, image
# Gradio UI setup
title = "Deepfake Detector with Interactive Feedback and Fine-tuning"

# Input widgets, in the positional order inference() receives its arguments.
ui_inputs = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Radio(["Correct", "Wrong"], label="Is the prediction correct?", value="Correct"),
    gr.Textbox(label="If Wrong, enter correct label (REAL or FAKE)", lines=1, placeholder="REAL or FAKE"),
]

# Output widgets, matching the (message, image) tuple inference() returns.
ui_outputs = [
    gr.Textbox(label="Output"),
    gr.Image(type="pil", label="Uploaded Image"),
]

iface = gr.Interface(
    fn=inference,
    inputs=ui_inputs,
    outputs=ui_outputs,
    title=title,
    live=False,
    allow_flagging="never",
)

# Start the web app as soon as the script runs.
iface.launch()
|