# DualSight-Demo / app.py
# Hugging Face Space by Abdu07 — commit 4fc801a ("Update app.py"), 4.01 kB.
import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download
########################################
# 1. Define the Model Architecture
########################################
class MultiTaskModel(nn.Module):
    """A shared backbone feeding two independent linear classification heads.

    The features produced by ``backbone`` go to:

    * ``obj_head`` -- object-recognition logits, one per object class;
    * ``bin_head`` -- two logits, index 0 = AI-generated, index 1 = Real.

    The attribute names ``backbone`` / ``obj_head`` / ``bin_head`` become the
    state-dict keys, so they must stay as-is for saved weights to load.
    """

    def __init__(self, backbone, feature_dim, num_obj_classes):
        super().__init__()
        self.backbone = backbone
        # NOTE(review): assumes backbone emits features whose last dimension
        # is feature_dim (e.g. [batch, feature_dim]) -- confirm for new backbones.
        self.obj_head = nn.Linear(feature_dim, num_obj_classes)
        self.bin_head = nn.Linear(feature_dim, 2)

    def forward(self, x):
        """Run the backbone once and return ``(obj_logits, bin_logits)``."""
        shared = self.backbone(x)
        return self.obj_head(shared), self.bin_head(shared)
########################################
# 2. Reconstruct the Model and Load Weights
########################################
# IMPORTANT: num_obj_classes must equal the number of object classes used in
# training, otherwise obj_head's shape will not match the saved weights.
num_obj_classes = 494  # Matches obj_head in the state dict from training
device = torch.device("cpu")

# Backbone: a ResNet-50 whose final classification layer is replaced by an
# identity, so it outputs features instead of ImageNet logits.
# `weights=None` is the current torchvision spelling of the deprecated
# `pretrained=False`; the random init is overwritten by the checkpoint below.
resnet = models.resnet50(weights=None)
# Read the feature width from the layer being removed instead of hard-coding
# it, so the heads always match the backbone's output size (2048 here).
feature_dim = resnet.fc.in_features
resnet.fc = nn.Identity()  # Remove final classification layer

# Build the model architecture.
model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
model.to(device)

# Download the trained state dict from the Hugging Face Hub.
repo_id = "Abdu07/multitask-model"
filename = "Yolloplusclassproject_weights.pth"
weights_path = hf_hub_download(repo_id=repo_id, filename=filename)

# SECURITY: torch.load unpickles the file, and this file comes from the
# network. weights_only=True limits unpickling to tensors and plain
# containers, so a tampered checkpoint cannot run arbitrary code.
state_dict = torch.load(weights_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # Evaluation mode for inference
########################################
# 3. Define Label Mappings and Transforms
########################################
# Label mappings.
# TODO: replace the placeholder object labels with the real mapping for the
# pseudo-labels produced during training. As written, every object
# prediction is reported as a generic "label_<index>" string.
idx_to_obj_label = dict(enumerate(f"label_{k}" for k in range(num_obj_classes)))
# Index order matches bin_head: 0 = AI-generated, 1 = Real.
bin_label_names = ["AI-Generated", "Real"]

# Evaluation-time preprocessing. It must match the validation pipeline used
# during training. NOTE(review): the mean/std values appear to be the usual
# ImageNet channel statistics -- confirm they match training.
val_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
########################################
# 4. Define the Inference Function
########################################
def predict_image(img: "Image.Image | None") -> str:
    """Classify an uploaded image with both heads of the multi-task model.

    Args:
        img: The PIL image supplied by the Gradio ``Image`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        ``"Prediction: <object label> (<AI-Generated|Real>)"``, or a short
        prompt asking for an image when ``img`` is ``None``.
    """
    # Gradio passes None when Submit is pressed with an empty image input;
    # calling .convert() on it would raise AttributeError.
    if img is None:
        return "No image provided. Please upload an image."

    # The transforms expect 3-channel input (e.g. RGBA PNGs, grayscale).
    img = img.convert("RGB")
    img_tensor = val_transforms(img).unsqueeze(0).to(device)  # [1, 3, 224, 224]

    with torch.no_grad():
        obj_logits, bin_logits = model(img_tensor)

    obj_pred = torch.argmax(obj_logits, dim=1).item()
    bin_pred = torch.argmax(bin_logits, dim=1).item()

    obj_name = idx_to_obj_label.get(obj_pred, "Unknown")
    bin_name = bin_label_names[bin_pred]
    return f"Prediction: {obj_name} ({bin_name})"
########################################
# 5. Create Gradio UI
########################################
# Single-input / single-output interface: a PIL image in, a text line out.
demo = gr.Interface(
    title="Multi-Task Image Classifier",
    description="\n".join(
        [
            "Upload an image to receive two predictions:",
            "1) The primary object in the image,",
            "2) Whether the image is AI-generated or Real.",
        ]
    ),
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
)
if __name__ == "__main__":
    # Listen on all interfaces and request a public share link.
    # NOTE(review): share=True exposes the app through a public tunnel when
    # run locally -- confirm that is intended.
    demo.launch(share=True, server_name="0.0.0.0")