# Interactive deepfake detector: classifies an uploaded image as REAL/FAKE and,
# when the user marks the prediction as wrong, takes one gradient step on that
# image with the user-supplied label and saves the updated checkpoint locally.
#
# `!pip install gradio` is IPython shell-escape syntax and is a SyntaxError in a
# plain .py file; installing only when the package is missing works both in a
# notebook and as a script.
import importlib.util
import os
import subprocess
import sys

if importlib.util.find_spec("gradio") is None:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "gradio"])

import gradio as gr
import torch
from PIL import Image
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import AutoImageProcessor, SiglipForImageClassification

# Load model and processor
model_name = "prithivMLmods/deepfake-detector-model-v1"
processor = AutoImageProcessor.from_pretrained(model_name)
model = SiglipForImageClassification.from_pretrained(model_name)
model.train()

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def _build_id2label(config):
    """Return an ``{class_index: "FAKE" | "REAL"}`` map for the loaded checkpoint.

    The labels passed to the model in ``fine_tune`` must use the same class
    indices as the classification head, so the checkpoint's own
    ``config.id2label`` is preferred (upper-cased). If the config does not name
    exactly the two classes FAKE and REAL (e.g. generic "LABEL_0"/"LABEL_1"),
    fall back to the hard-coded ``{0: "FAKE", 1: "REAL"}`` mapping.
    """
    fallback = {0: "FAKE", 1: "REAL"}
    raw = getattr(config, "id2label", None) or {}
    mapping = {int(k): str(v).strip().upper() for k, v in raw.items()}
    if sorted(mapping.values()) != ["FAKE", "REAL"]:
        return fallback
    return mapping


# Labels mapping, kept consistent with the model head (see _build_id2label).
id2label = _build_id2label(model.config)
label2id = {label: idx for idx, label in id2label.items()}

# Optimizer for fine-tuning; small learning rate because each update is driven
# by a single user-corrected image.
optimizer = AdamW(model.parameters(), lr=5e-6)


class SingleImageDataset(Dataset):
    """A one-element dataset wrapping a single image and its class index.

    Args:
        image: The image passed to ``processor(images=...)``.
        label: Integer class index used as the ``labels`` target.
    """

    def __init__(self, image, label):
        self.image = image
        self.label = label

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        inputs = processor(images=self.image, return_tensors="pt")
        # The processor returns a batch of one; drop that leading dim because
        # the DataLoader's collation adds its own batch dimension back.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs["labels"] = torch.tensor(self.label)
        return inputs


def fine_tune(image, correct_label):
    """Take one optimisation pass on ``image`` with target ``correct_label``.

    Mutates the global ``model`` and ``optimizer`` in place, then writes the
    updated model and processor to ``./fine_tuned_model`` (overwriting any
    previous save).

    Args:
        image: Image to train on (RGB PIL image from ``inference``).
        correct_label: Integer class index, a value from ``label2id``.
    """
    dataloader = DataLoader(SingleImageDataset(image, correct_label), batch_size=1)
    model.train()
    for _ in range(1):  # just 1 epoch for fast feedback
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = model(**batch).loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Save the updated model locally
    save_path = "./fine_tuned_model"
    os.makedirs(save_path, exist_ok=True)
    model.save_pretrained(save_path)
    processor.save_pretrained(save_path)


def predict(image):
    """Return the predicted label string ("FAKE" or "REAL") for ``image``."""
    model.eval()
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    return id2label[logits.argmax(-1).item()]


def inference(image, feedback, correct_label_text):
    """Gradio callback: predict, and optionally fine-tune on user feedback.

    Args:
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        feedback: "Correct" or "Wrong" from the radio control.
        correct_label_text: Free-text label typed by the user; only consulted
            when ``feedback == "Wrong"``. Matched case-insensitively after
            trimming whitespace.

    Returns:
        ``(message, image)`` — a status string and the original upload echoed
        back, or ``(prompt, None)`` when no image was given.
    """
    if image is None:
        return "Please upload an image.", None

    # Uploads may be RGBA, grayscale or palette images; normalise to 3-channel
    # RGB so the processor always produces the expected pixel layout.
    rgb_image = image.convert("RGB")
    prediction = predict(rgb_image)
    message = f"Prediction: {prediction}"

    if feedback == "Wrong":
        # Tolerate None/empty input and stray whitespace such as "REAL ".
        label_text = (correct_label_text or "").strip().upper()
        if label_text in label2id:
            fine_tune(rgb_image, label2id[label_text])
            message += f" | Model fine-tuned with correct label: {label_text}"
        else:
            message += " | Please enter a valid correct label (REAL or FAKE)."

    return message, image


# Gradio UI setup
title = "Deepfake Detector with Interactive Feedback and Fine-tuning"
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Radio(["Correct", "Wrong"], label="Is the prediction correct?", value="Correct"),
        gr.Textbox(label="If Wrong, enter correct label (REAL or FAKE)", lines=1, placeholder="REAL or FAKE"),
    ],
    outputs=[
        gr.Textbox(label="Output"),
        gr.Image(type="pil", label="Uploaded Image"),
    ],
    title=title,
    live=False,
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the installed version.
    allow_flagging="never",
)

# `__name__` is "__main__" in a notebook as well as when run as a script, so the
# server still launches there; importing this module no longer starts it.
if __name__ == "__main__":
    iface.launch()