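# TruthLens / app.py
# Author: Duncan222 — commit 5632cc9 ("Create app.py")
# Requires gradio: run `pip install gradio` before launching. (This was a
# notebook-only `!pip install gradio` magic, which is a SyntaxError in a .py file.)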
import gradio as gr
from transformers import AutoImageProcessor, SiglipForImageClassification
from torch.optim import AdamW
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import os
# --- Model / processor / device setup ----------------------------------------
# Pretrained SigLIP-based deepfake classifier from the Hugging Face Hub; the
# processor converts images into the tensors the model expects.
model_name = "prithivMLmods/deepfake-detector-model-v1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = AutoImageProcessor.from_pretrained(model_name)
# Move the classifier to the chosen device; nn.Module.to() returns the module itself.
model = SiglipForImageClassification.from_pretrained(model_name).to(device)
model.train()

# Class-id <-> label-name mapping used by predict() and the feedback handler.
# NOTE(review): hard-coded rather than read from model.config.id2label —
# confirm it matches the checkpoint's label order.
id2label = {0: "FAKE", 1: "REAL"}
label2id = {name: class_id for class_id, name in id2label.items()}

# One shared AdamW optimizer over every model parameter, reused by all
# fine-tuning calls. The very small learning rate (5e-6) presumably keeps
# single-example updates gentle.
optimizer = AdamW(model.parameters(), lr=5e-6)
# Dataset class for single example fine-tuning
class SingleImageDataset(Dataset):
    """A one-element dataset wrapping a single (image, label) pair.

    Used to fine-tune the model on one user-corrected example. Each access
    runs the module-level ``processor`` on the stored image and returns a dict
    of un-batched tensors plus a ``labels`` tensor, so a ``DataLoader`` can
    collate it and the result can be passed as keyword arguments to ``model``.
    """

    def __init__(self, image, label):
        # Raw image handed to the processor (presumably a PIL image from the UI).
        self.image = image
        # Integer class id, e.g. a value from ``label2id``.
        self.label = label

    def __len__(self):
        # There is always exactly one example.
        return 1

    def __getitem__(self, idx):
        # ``idx`` is ignored: index 0 is the only valid position.
        encoded = processor(images=self.image, return_tensors="pt")
        item = {}
        for key, value in encoded.items():
            # Drop the batch dimension the processor adds; the DataLoader
            # re-adds it when collating.
            item[key] = value.squeeze(0)
        item["labels"] = torch.tensor(self.label)
        return item
def fine_tune(image, correct_label, *, epochs=1, save_path="./fine_tuned_model"):
    """Update the shared model in place using one user-corrected example.

    Wraps ``image``/``correct_label`` in a ``SingleImageDataset``, runs
    ``epochs`` passes of optimizer steps on the module-level ``model`` with the
    module-level ``optimizer``, then optionally writes the updated model and
    processor to ``save_path``.

    Args:
        image: Image to learn from; passed straight to ``processor``
            (presumably a PIL image from the Gradio UI).
        correct_label: Integer class id from ``label2id`` (0 = FAKE, 1 = REAL).
        epochs: Number of passes over the single example. Defaults to 1 to
            keep interactive feedback fast.
        save_path: Directory to save the model and processor to, or ``None``
            to skip saving. Defaults to ``"./fine_tuned_model"``.

    Returns:
        None. ``model`` is modified in place and left in training mode.
    """
    dataloader = DataLoader(SingleImageDataset(image, correct_label), batch_size=1)

    model.train()
    for _ in range(epochs):
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            # The batch includes "labels", so the model computes the loss itself.
            loss = model(**batch).loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    if save_path is not None:
        # NOTE(review): this rewrites the full checkpoint on every call, and
        # nothing visible here reloads it at startup (the model is always
        # loaded from ``model_name``).
        os.makedirs(save_path, exist_ok=True)
        model.save_pretrained(save_path)
        processor.save_pretrained(save_path)
def predict(image):
    """Classify ``image`` with the current model and return its label name.

    Puts the shared ``model`` into evaluation mode (``fine_tune`` leaves it in
    training mode), runs a single forward pass without gradient tracking, and
    maps the highest-scoring class id through ``id2label``.

    Returns:
        The ``id2label`` entry for the top class (``"FAKE"`` or ``"REAL"``).
    """
    model.eval()
    encoded = processor(images=image, return_tensors="pt")
    encoded = encoded.to(device)
    with torch.no_grad():
        scores = model(**encoded).logits
    best_class = torch.argmax(scores, dim=-1).item()
    return id2label[best_class]
def inference(image, feedback, correct_label_text):
    """Gradio callback: predict on ``image`` and optionally fine-tune on feedback.

    Args:
        image: The uploaded image, or ``None`` if nothing was uploaded.
        feedback: ``"Correct"`` or ``"Wrong"`` from the radio button.
        correct_label_text: The label typed by the user. It is matched
            case-insensitively against ``"REAL"``/``"FAKE"``, with surrounding
            whitespace ignored. ``None`` is treated as an empty string.

    Returns:
        A ``(message, image)`` tuple for the two output components. The
        message reports the prediction made *before* any fine-tuning.
    """
    if image is None:
        return "Please upload an image.", None

    prediction = predict(image)
    message = f"Prediction: {prediction}"

    if feedback == "Wrong":
        # Normalize once: accept " real " and similar, and don't crash on a
        # None textbox value (the original called .upper() on it directly).
        label_text = (correct_label_text or "").strip().upper()
        correct_label = label2id.get(label_text)
        if correct_label is not None:
            fine_tune(image, correct_label)
            message += f" | Model fine-tuned with correct label: {label_text}"
        else:
            message += " | Please enter a valid correct label (REAL or FAKE)."
    return message, image
# --- Gradio UI ---------------------------------------------------------------
title = "Deepfake Detector with Interactive Feedback and Fine-tuning"

# Inputs, in the order inference() receives them: the image, the user's
# verdict on the prediction, and the corrected label when the verdict is "Wrong".
_ui_inputs = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Radio(["Correct", "Wrong"], label="Is the prediction correct?", value="Correct"),
    gr.Textbox(
        label="If Wrong, enter correct label (REAL or FAKE)",
        lines=1,
        placeholder="REAL or FAKE",
    ),
]

# Outputs, matching inference()'s (message, image) return tuple.
_ui_outputs = [
    gr.Textbox(label="Output"),
    gr.Image(type="pil", label="Uploaded Image"),
]

iface = gr.Interface(
    fn=inference,
    inputs=_ui_inputs,
    outputs=_ui_outputs,
    title=title,
    live=False,  # run only when the user submits
    allow_flagging="never",
)
iface.launch()