# DualSight-Demo / app.py
# Hugging Face Space by Abdu07 - commit 2dfeda2 ("Update app.py", verified), 3.8 kB.
import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download
########################################
# 1. Define the Model Architecture
########################################
class MultiTaskModel(nn.Module):
    """Shared-backbone network with two classification heads.

    The backbone's features are fed to both an object-recognition head and
    a binary head whose classes are 0 = AI-generated, 1 = Real.

    Args:
        backbone: Module mapping an input batch to features of width
            ``feature_dim`` (assumed 2-D ``(batch, feature_dim)`` output -
            the backbone is supplied by the caller).
        feature_dim: Width of the backbone's output features.
        num_obj_classes: Number of logits produced by the object head.
    """

    def __init__(self, backbone, feature_dim, num_obj_classes):
        super().__init__()
        # These attribute names are the state_dict key prefixes
        # ("backbone.*", "obj_head.*", "bin_head.*"); renaming them would
        # break loading of previously saved checkpoints.
        self.backbone = backbone
        self.obj_head = nn.Linear(feature_dim, num_obj_classes)
        self.bin_head = nn.Linear(feature_dim, 2)

    def forward(self, x):
        """Return ``(obj_logits, bin_logits)`` computed from one shared pass."""
        shared = self.backbone(x)
        return self.obj_head(shared), self.bin_head(shared)
########################################
# 2. Reconstruct the Model and Load Weights
########################################
# Inference device; this app runs on CPU.
device = torch.device("cpu")

# Download the trained state dict from the Hugging Face Hub.
repo_id = "Abdu07/multitask-model"  # Hub repository holding the weights
filename = "Yolloplusclassproject_weights.pth"  # weight file inside that repo
weights_path = hf_hub_download(repo_id=repo_id, filename=filename)

# Load the checkpoint. weights_only=True limits unpickling to tensors and
# plain containers, so a tampered file fetched from the network cannot run
# arbitrary code at load time.
state_dict = torch.load(weights_path, map_location=device, weights_only=True)

# Backbone: ResNet-50 without pretrained weights (every parameter is
# overwritten by the strict load_state_dict below). `weights=None` replaces
# the deprecated `pretrained=False` argument.
resnet = models.resnet50(weights=None)
# Read the feature width before dropping the classifier, so it always
# matches the architecture instead of being hard-coded.
feature_dim = resnet.fc.in_features
resnet.fc = nn.Identity()  # expose backbone features instead of ImageNet logits

# Take the number of object classes from the checkpoint itself so the head
# always matches the trained weights. If the key is missing, fall back to the
# previous hard-coded value and let load_state_dict report the mismatch with
# its full missing/unexpected-key diagnostics.
_obj_head_weight = state_dict.get("obj_head.weight")
num_obj_classes = (
    _obj_head_weight.shape[0] if _obj_head_weight is not None else 139
)

# Build the architecture, load the trained weights, and switch to eval mode
# (fixes BatchNorm statistics for inference).
model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
model.to(device)
model.load_state_dict(state_dict)
model.eval()
########################################
# 3. Define Label Mappings and Transforms
########################################
# Object-class names keyed by the object head's output index.
# NOTE(review): this is a placeholder with only three entries, while the
# model predicts many more object classes; any index missing here is shown
# as "Unknown" by predict_image. Replace it with the full mapping from the
# training run (the original labels for indices 0-2 may also differ).
idx_to_obj_label = dict(enumerate(["cat", "dog", "car"]))

# Binary-head class names; list position equals the class index.
bin_label_names = ["AI-Generated", "Real"]

# Standard ImageNet per-channel mean/std used by torchvision's pretrained
# models.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Evaluation-time preprocessing; must mirror the pipeline used in training.
val_transforms = transforms.Compose(
    [
        transforms.Resize(256),  # scale the shorter side to 256 px
        transforms.CenterCrop(224),  # keep the central 224x224 patch
        transforms.ToTensor(),  # PIL image -> float CHW tensor in [0, 1]
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
########################################
# 4. Define the Inference Function
########################################
def predict_image(img: Image.Image) -> str:
    """Classify an uploaded image with both heads of the multi-task model.

    Args:
        img: Image from the ``gr.Image(type="pil")`` input. Gradio passes
            ``None`` when the user submits without uploading anything.

    Returns:
        ``"Prediction: <object> (<AI-Generated|Real>)"``, or a prompt to
        upload an image when ``img`` is ``None``.
    """
    # Submitting with no upload would otherwise crash with
    # AttributeError on None.convert().
    if img is None:
        return "Please upload an image."

    # Normalize palette/RGBA/grayscale uploads to 3 RGB channels.
    rgb = img.convert("RGB")
    # Preprocess and add a batch dimension -> [1, 3, 224, 224].
    batch = val_transforms(rgb).unsqueeze(0).to(device)

    # Pure inference: no autograd bookkeeping needed.
    with torch.inference_mode():
        obj_logits, bin_logits = model(batch)

    obj_pred = obj_logits.argmax(dim=1).item()
    bin_pred = bin_logits.argmax(dim=1).item()

    # The label map may be incomplete; unmapped indices fall back to "Unknown".
    obj_name = idx_to_obj_label.get(obj_pred, "Unknown")
    bin_name = bin_label_names[bin_pred]
    return f"Prediction: {obj_name} ({bin_name})"
########################################
# 5. Create Gradio UI
########################################
# Gradio front end: one PIL image in, one text prediction out.
demo = gr.Interface(
    predict_image,
    gr.Image(type="pil"),
    "text",
    title="Multi-Task Image Classifier",
    description=(
        "Upload an image to receive two predictions:\n"
        "1) The primary object in the image,\n"
        "2) Whether the image is AI-generated or Real."
    ),
)

if __name__ == "__main__":
    # Bind to all network interfaces and request a public share link.
    # NOTE(review): share links are reportedly not supported on Hugging Face
    # Spaces (Gradio ignores the flag there with a warning); locally this
    # exposes the app through a public URL - confirm that is intended.
    launch_options = {"server_name": "0.0.0.0", "share": True}
    demo.launch(**launch_options)