import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download
import json
########################################
# 1. Define the Model Architecture
########################################
class MultiTaskModel(nn.Module):
    def __init__(self, backbone, feature_dim, num_obj_classes):
        super(MultiTaskModel, self).__init__()
        self.backbone = backbone
        self.obj_head = nn.Linear(feature_dim, num_obj_classes)
        self.bin_head = nn.Linear(feature_dim, 2)

    def forward(self, x):
        feats = self.backbone(x)
        obj_logits = self.obj_head(feats)
        bin_logits = self.bin_head(feats)
        return obj_logits, bin_logits
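# Illustrative shape note (added for clarity): with the ResNet-50 backbone built
# below (its fc layer replaced by nn.Identity), a batch of N 224x224 RGB crops
# yields features of shape (N, 2048), so obj_logits has shape (N, num_obj_classes)
# and bin_logits has shape (N, 2).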
########################################
# 2. Reconstruct the Model and Load Weights
########################################
# IMPORTANT: The checkpoint was saved with a single object class,
# so we set num_obj_classes to 1.
num_obj_classes = 1
device = torch.device("cpu")
resnet = models.resnet50(weights=None)  # randomly initialized; fine-tuned weights are loaded below
resnet.fc = nn.Identity()               # expose the 2048-d pooled features instead of class logits
feature_dim = 2048
model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
model.to(device)
repo_id = "Abdu07/multitask-model"
filename = "DualSight.pth"
weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
state_dict = torch.load(weights_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
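# Note: model.eval() switches the backbone's BatchNorm layers to inference mode
# (running statistics); gradient tracking is disabled separately via torch.no_grad()
# inside predict_image below.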
########################################
# 3. Load Label Mapping and Define Transforms
########################################
# Attempt to load the object-label mapping from JSON.
# If the mapping contains more than one label, we override it with a single-label mapping.
try:
    with open("obj_label_mapping.json", "r") as f:
        obj_label_to_idx = json.load(f)
    if len(obj_label_to_idx) != 1:
        obj_label_to_idx = {"Detected Object": 0}
except Exception as e:
    print("Error loading mapping, using default mapping. Error:", e)
    obj_label_to_idx = {"Detected Object": 0}
idx_to_obj_label = {v: k for k, v in obj_label_to_idx.items()}
bin_label_names = ["AI-Generated", "Real"]
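# Assumption: the index order here (0 = "AI-Generated", 1 = "Real") matches the
# label order used when the binary head was trained.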
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
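# These mean/std values are the standard ImageNet normalization statistics used by
# torchvision's pretrained models; they are assumed to match the preprocessing used
# during training.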
########################################
# 4. Define the Inference Function
########################################
def predict_image(img: Image.Image) -> str:
    img = img.convert("RGB")
    img_tensor = val_transforms(img).unsqueeze(0).to(device)
    with torch.no_grad():
        obj_logits, bin_logits = model(img_tensor)
    obj_pred = torch.argmax(obj_logits, dim=1).item()
    bin_pred = torch.argmax(bin_logits, dim=1).item()
    obj_name = idx_to_obj_label.get(obj_pred, "Unknown")
    bin_name = bin_label_names[bin_pred]
    return f"Prediction: {obj_name} ({bin_name})"
########################################
# 5. Create Gradio UI
########################################
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Multi-Task Image Classifier",
    description="Upload an image to receive two predictions:\n1) The primary object in the image,\n2) Whether the image is AI-generated or Real."
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", share=True)