DualSight-Demo / app.py
Abdu07's picture
Update app.py
7efb51a verified
raw
history blame
3.41 kB
import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download
import json
########################################
# 1. Define the Model Architecture
########################################
class MultiTaskModel(nn.Module):
def __init__(self, backbone, feature_dim, num_obj_classes):
super(MultiTaskModel, self).__init__()
self.backbone = backbone
self.obj_head = nn.Linear(feature_dim, num_obj_classes)
self.bin_head = nn.Linear(feature_dim, 2)
def forward(self, x):
feats = self.backbone(x)
obj_logits = self.obj_head(feats)
bin_logits = self.bin_head(feats)
return obj_logits, bin_logits
########################################
# 2. Load the Label Mapping and Set num_obj_classes
########################################
# Load the saved mapping from JSON
with open("obj_label_mapping.json", "r") as f:
obj_label_to_idx = json.load(f)
# Use the mapping as-is; do not override it.
num_obj_classes = len(obj_label_to_idx)
# Create the inverse mapping
idx_to_obj_label = {v: k for k, v in obj_label_to_idx.items()}
bin_label_names = ["AI-Generated", "Real"]
########################################
# 3. Define Validation Transforms
########################################
val_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
########################################
# 4. Reconstruct the Model and Load Weights
########################################
device = torch.device("cpu")
resnet = models.resnet50(pretrained=False)
resnet.fc = nn.Identity()
feature_dim = 2048
# Build the model architecture.
model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
model.to(device)
repo_id = "Abdu07/multitask-model"
filename = "DualSight.pth" # Ensure this checkpoint is from training with the same num_obj_classes
weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
state_dict = torch.load(weights_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
########################################
# 5. Define the Inference Function
########################################
def predict_image(img: Image.Image) -> str:
img = img.convert("RGB")
img_tensor = val_transforms(img).unsqueeze(0).to(device)
with torch.no_grad():
obj_logits, bin_logits = model(img_tensor)
obj_pred = torch.argmax(obj_logits, dim=1).item()
bin_pred = torch.argmax(bin_logits, dim=1).item()
obj_name = idx_to_obj_label.get(obj_pred, "Unknown")
bin_name = bin_label_names[bin_pred]
return f"Prediction: {obj_name} ({bin_name})"
########################################
# 6. Create Gradio UI
########################################
demo = gr.Interface(
fn=predict_image,
inputs=gr.Image(type="pil"),
outputs="text",
title="Multi-Task Image Classifier Trained by [Abdellahi El Moustapha](https://abmstpha.github.io/),
description="Upload an image to receive two predictions:\n1) The primary object in the image,\n2) Whether the image is AI-generated or Real."
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", share=True)