Abdu07 committed on
Commit
a94f8d1
·
verified ·
1 Parent(s): f3d6a41

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.models as models
5
+ import torchvision.transforms as transforms
6
+ from PIL import Image
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ ########################################
10
+ # 1. Define the Model Architecture
11
+ ########################################
12
class MultiTaskModel(nn.Module):
    """Shared-backbone network with two linear classification heads.

    Both heads consume the same backbone output, so ``backbone`` must return
    a tensor whose last dimension equals ``feature_dim``.

    Args:
        backbone: Feature extractor module applied to the input batch.
        feature_dim: Size of the feature vector produced by ``backbone``.
        num_obj_classes: Number of logits produced by the object head.
    """

    def __init__(self, backbone, feature_dim, num_obj_classes):
        super().__init__()
        # Attribute names double as state-dict key prefixes; keep them stable
        # so previously saved weights still load.
        self.backbone = backbone
        # Object recognition head
        self.obj_head = nn.Linear(feature_dim, num_obj_classes)
        # Binary classification head (0: AI-generated, 1: Real)
        self.bin_head = nn.Linear(feature_dim, 2)

    def forward(self, x):
        """Return ``(obj_logits, bin_logits)`` computed from shared features."""
        feats = self.backbone(x)
        obj_logits = self.obj_head(feats)
        bin_logits = self.bin_head(feats)
        return obj_logits, bin_logits
26
+
27
+ ########################################
28
+ # 2. Reconstruct the Model and Load Weights
29
+ ########################################
30
# Set the number of object classes (update this to match your training)
num_obj_classes = 139  # Example value; change it to your actual number

device = torch.device("cpu")

# Instantiate the backbone: a ResNet-50 with its final layer removed.
resnet = models.resnet50(pretrained=False)
# Read the feature width from the layer being removed instead of hard-coding
# it, so the heads stay in sync with the backbone if it is ever swapped.
feature_dim = resnet.fc.in_features
resnet.fc = nn.Identity()  # Remove final classification layer

# Build the model architecture.
model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
model.to(device)

# Download the state dict from HF Hub.
repo_id = "Abdu07/multitask-model"  # Your repo name
filename = "multitask_model_weights.pth"  # The state dict file you uploaded
weights_path = hf_hub_download(repo_id=repo_id, filename=filename)

# Load the state dict and update the model.
# NOTE(review): torch.load unpickles the downloaded file. If the installed
# PyTorch supports it, consider passing weights_only=True, since only a plain
# state dict is expected here -- confirm against the pinned torch version.
state_dict = torch.load(weights_path, map_location=device)
# Strict loading: this raises if num_obj_classes does not match the
# checkpoint's object-head shape.
model.load_state_dict(state_dict)
model.eval()  # Inference mode for dropout / batch-norm layers.
53
+
54
+ ########################################
55
+ # 3. Define Label Mappings and Transforms
56
+ ########################################
57
# Update these with your actual label mappings.
# NOTE: only three indices are mapped here while the model is built with
# 139 object classes; predict_image falls back to "Unknown" for any index
# that is missing from this dict.
idx_to_obj_label = {
    0: "cat",
    1: "dog",
    2: "car",
    # ... add the rest of your object classes ...
}
# Index order must follow the binary head: 0 -> AI-generated, 1 -> Real.
bin_label_names = ["AI-Generated", "Real"]

# Define the validation transforms (must match those used during training)
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
74
+
75
+ ########################################
76
+ # 4. Define the Inference Function
77
+ ########################################
78
def predict_image(img: "Image.Image | None") -> str:
    """Classify an uploaded image with both model heads.

    Args:
        img: The uploaded PIL image, or ``None`` when the form was submitted
            without an image.

    Returns:
        ``"Prediction: <object> (<AI-Generated|Real>)"``, or a short prompt
        asking for an upload when ``img`` is ``None``. ``<object>`` is
        ``"Unknown"`` when the predicted index has no entry in
        ``idx_to_obj_label``.
    """
    # Gradio passes None when nothing was uploaded; without this guard the
    # call below fails with AttributeError and the UI only shows "Error".
    if img is None:
        return "Please upload an image."

    # Ensure the image is in RGB mode.
    img = img.convert("RGB")
    # Apply validation transforms and add the batch dimension.
    img_tensor = val_transforms(img).unsqueeze(0).to(device)

    # No gradients are needed for inference.
    with torch.no_grad():
        obj_logits, bin_logits = model(img_tensor)

    obj_pred = torch.argmax(obj_logits, dim=1).item()
    bin_pred = torch.argmax(bin_logits, dim=1).item()
    obj_name = idx_to_obj_label.get(obj_pred, "Unknown")
    bin_name = bin_label_names[bin_pred]
    return f"Prediction: {obj_name} ({bin_name})"
93
+
94
+ ########################################
95
+ # 5. Create Gradio UI
96
+ ########################################
97
# Single-function Gradio app: one image in, one line of text out.
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),  # Hand the callback a PIL image.
    outputs="text",
    title="Multi-Task Image Classifier",
    description=(
        "Upload an image to receive two predictions:\n"
        "1) The primary object in the image,\n"
        "2) Whether the image is AI-generated or Real."
    ),
)

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside a container.
    # NOTE(review): share=True also requests a public tunnel link; this is
    # presumably unnecessary when hosted on Hugging Face Spaces -- confirm
    # before removing.
    demo.launch(server_name="0.0.0.0", share=True)