Spaces:
Running
Running
File size: 4,007 Bytes
a94f8d1 4fc801a a94f8d1 4fc801a a94f8d1 4fc801a a94f8d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download
########################################
# 1. Define the Model Architecture
########################################
class MultiTaskModel(nn.Module):
    """Shared-backbone network with two linear classification heads.

    One feature vector produced by ``backbone`` feeds both an object-recognition
    head and a two-way head (index 0: AI-generated, index 1: Real).

    Args:
        backbone: Module mapping an input batch to features of width
            ``feature_dim`` (presumably shape ``[batch, feature_dim]`` —
            the heads are plain ``nn.Linear`` layers).
        feature_dim: Width of the backbone's output features.
        num_obj_classes: Number of object-recognition classes.
    """

    def __init__(self, backbone, feature_dim, num_obj_classes):
        super().__init__()
        # These attribute names are the state-dict key prefixes ("backbone.",
        # "obj_head.", "bin_head.") used when loading saved weights; keep them.
        self.backbone = backbone
        self.obj_head = nn.Linear(feature_dim, num_obj_classes)
        self.bin_head = nn.Linear(feature_dim, 2)

    def forward(self, x):
        """Return ``(obj_logits, bin_logits)`` computed from shared features."""
        shared = self.backbone(x)
        return self.obj_head(shared), self.bin_head(shared)
########################################
# 2. Reconstruct the Model and Load Weights
########################################
# IMPORTANT: num_obj_classes must equal the number of object classes used in
# training; model.load_state_dict below fails on a shape mismatch otherwise.
num_obj_classes = 494  # matches obj_head in the trained state dict
device = torch.device("cpu")

# Backbone: ResNet-50 built WITHOUT pretrained weights (the checkpoint loaded
# below supplies every parameter). Calling resnet50() with no arguments gives an
# untrained model on both old torchvision (pretrained=False default) and new
# torchvision (weights=None default), and avoids the `pretrained=` keyword that
# torchvision deprecated in 0.13.
resnet = models.resnet50()
resnet.fc = nn.Identity()  # drop the final classifier so the backbone emits raw features
feature_dim = 2048  # width of ResNet-50's pooled features (input size of its original fc)

# Build the model architecture.
model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
model.to(device)

# Download the trained state dict from the Hugging Face Hub.
repo_id = "Abdu07/multitask-model"  # Hub repository holding the weights
filename = "Yolloplusclassproject_weights.pth"  # state-dict file in that repo
weights_path = hf_hub_download(repo_id=repo_id, filename=filename)

# SECURITY: torch.load is pickle-based and this file comes from a remote
# repository. weights_only=True restricts unpickling to tensors and plain
# containers, so a tampered checkpoint cannot execute code (needs torch >= 1.13).
state_dict = torch.load(weights_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)  # strict by default: keys and shapes must match
model.eval()  # use BatchNorm running statistics for inference
########################################
# 3. Define Label Mappings and Transforms
########################################
# Label lookup tables.
# NOTE: idx_to_obj_label is a PLACEHOLDER mapping index i -> "label_i" for all
# num_obj_classes object classes. Replace it with the real index -> name mapping
# from the pseudo-labels produced during training; until then the app shows
# only these placeholder names.
idx_to_obj_label = dict(enumerate(f"label_{i}" for i in range(num_obj_classes)))
# Names for the binary head's outputs, indexed by its argmax (0, 1).
bin_label_names = ["AI-Generated", "Real"]

# Inference preprocessing. This must mirror the validation transforms used in
# training; mean/std are the values commonly used with ImageNet-pretrained
# torchvision models.
val_transforms = transforms.Compose(
    [
        transforms.Resize(256),  # shorter side -> 256 px
        transforms.CenterCrop(224),  # central 224x224 patch
        transforms.ToTensor(),  # PIL image -> float tensor
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
########################################
# 4. Define the Inference Function
########################################
def predict_image(img: "Image.Image | None") -> str:
    """Classify an uploaded image with both heads of the multi-task model.

    Args:
        img: PIL image from the Gradio upload widget, or ``None`` when the
            user submits without choosing an image.

    Returns:
        ``"Prediction: <object label> (<AI-Generated|Real>)"``, or a prompt
        asking for an image when ``img`` is ``None``.
    """
    # Gradio hands the function None when Submit is pressed with an empty
    # image input; without this guard img.convert() raises AttributeError.
    if img is None:
        return "Please upload an image."
    # Force 3 channels so grayscale/RGBA/palette images match the 3-channel
    # normalization in val_transforms.
    img = img.convert("RGB")
    # Add a batch dimension: [3, 224, 224] -> [1, 3, 224, 224].
    img_tensor = val_transforms(img).unsqueeze(0).to(device)
    with torch.no_grad():
        obj_logits, bin_logits = model(img_tensor)
    obj_pred = torch.argmax(obj_logits, dim=1).item()
    bin_pred = torch.argmax(bin_logits, dim=1).item()
    obj_name = idx_to_obj_label.get(obj_pred, "Unknown")
    bin_name = bin_label_names[bin_pred]
    return f"Prediction: {obj_name} ({bin_name})"
########################################
# 5. Create Gradio UI
########################################
# Help text shown under the title: one line per prediction the app returns.
_description = "\n".join(
    [
        "Upload an image to receive two predictions:",
        "1) The primary object in the image,",
        "2) Whether the image is AI-generated or Real.",
    ]
)

# Single-input, single-output UI: uploads reach predict_image as PIL images
# and its returned string is shown in a text box.
demo = gr.Interface(
    title="Multi-Task Image Classifier",
    description=_description,
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
)

if __name__ == "__main__":
    # Listen on all network interfaces; share=True also asks Gradio for a
    # public share link when run outside Hugging Face Spaces.
    demo.launch(server_name="0.0.0.0", share=True)
|