import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download

########################################
# 1. Define the Model Architecture
########################################
class MultiTaskModel(nn.Module):
    def __init__(self, backbone, feature_dim, num_obj_classes):
        super().__init__()
        self.backbone = backbone
        # Object recognition head
        self.obj_head = nn.Linear(feature_dim, num_obj_classes)
        # Binary classification head (0: AI-generated, 1: Real)
        self.bin_head = nn.Linear(feature_dim, 2)
    
    def forward(self, x):
        feats = self.backbone(x)
        obj_logits = self.obj_head(feats)
        bin_logits = self.bin_head(feats)
        return obj_logits, bin_logits
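
# Shape sanity check (illustrative only, commented out): with a ResNet-50 backbone whose
# fc layer is replaced by nn.Identity (feature_dim = 2048), the two heads should produce
# [batch, num_obj_classes] and [batch, 2] logits respectively:
# _backbone = models.resnet50(weights=None); _backbone.fc = nn.Identity()
# _obj, _bin = MultiTaskModel(_backbone, 2048, 494)(torch.randn(1, 3, 224, 224))
# assert _obj.shape == (1, 494) and _bin.shape == (1, 2)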

########################################
# 2. Reconstruct the Model and Load Weights
########################################
# IMPORTANT: Use the same number of object classes as in training.
num_obj_classes = 494  # Updated to match the state dict from training

device = torch.device("cpu")

# Instantiate the backbone: a ResNet-50 with its final layer removed.
resnet = models.resnet50(weights=None)  # No ImageNet weights; the fine-tuned checkpoint is loaded below
resnet.fc = nn.Identity()  # Expose the 2048-dim pooled features instead of classification logits
feature_dim = 2048

# Build the model architecture.
model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
model.to(device)

# Download the state dict from HF Hub.
repo_id = "Abdu07/multitask-model"  # Your repo name
filename = "Yolloplusclassproject_weights.pth"  # The state dict file you uploaded
weights_path = hf_hub_download(repo_id=repo_id, filename=filename)

# Load the state dict and update the model.
state_dict = torch.load(weights_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
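
# Note (assumption: PyTorch >= 1.13): passing weights_only=True to torch.load restricts
# unpickling to plain tensors/containers, which is safer for downloaded checkpoints:
# state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)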

########################################
# 3. Define Label Mappings and Transforms
########################################
# Update these with your actual label mappings.
# They should reflect the 494 unique pseudo-labels produced during training
# (e.g. a dict saved alongside the checkpoint). The dummy mapping below is only
# a placeholder for illustration; replace it with your real mapping.
idx_to_obj_label = {i: f"label_{i}" for i in range(num_obj_classes)}
bin_label_names = ["AI-Generated", "Real"]
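
# Sketch (assumption): if the real mapping was saved during training, e.g. as a JSON file
# named "idx_to_obj_label.json" (hypothetical name), it could be loaded like this instead
# of the dummy mapping above:
# import json
# with open("idx_to_obj_label.json") as f:
#     idx_to_obj_label = {int(k): v for k, v in json.load(f).items()}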

# Define the validation transforms (must match those used during training)
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

########################################
# 4. Define the Inference Function
########################################
def predict_image(img: Image.Image) -> str:
    """
    Takes an uploaded PIL image, processes it, and returns the model's prediction.
    """
    # Ensure the image is in RGB mode.
    img = img.convert("RGB")
    # Apply validation transforms.
    img_tensor = val_transforms(img).unsqueeze(0).to(device)  # Shape: [1, 3, 224, 224]
    with torch.no_grad():
        obj_logits, bin_logits = model(img_tensor)
    obj_pred = torch.argmax(obj_logits, dim=1).item()
    bin_pred = torch.argmax(bin_logits, dim=1).item()
    obj_name = idx_to_obj_label.get(obj_pred, "Unknown")
    bin_name = bin_label_names[bin_pred]
    return f"Prediction: {obj_name} ({bin_name})"

########################################
# 5. Create Gradio UI
########################################
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Multi-Task Image Classifier",
    description=(
        "Upload an image to receive two predictions:\n"
        "1) The primary object in the image,\n"
        "2) Whether the image is AI-generated or Real."
    )
)
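
# Optional (assumption): sample images could be surfaced in the UI via the `examples`
# argument of gr.Interface, e.g. examples=["examples/real.jpg", "examples/ai.jpg"]
# (hypothetical paths); omitted here because no example files ship with this Space.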

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=True)