Abdu07 commited on
Commit
5d0efae
·
verified ·
1 Parent(s): f1e05c5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import requests
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ ########################
10
+ # 1) Download & Load Model
11
+ ########################
12
+
13
+ # Replace with your actual model repo on HF
14
+ repo_id = "Abdu07/multitask-model"
15
+ filename = "multitask_model.pth"
16
+
17
+ # Download the model file from the Hub
18
+ model_path = hf_hub_download(repo_id=repo_id, filename=filename)
19
+ model = torch.load(model_path, map_location="cpu") # or map_location="cuda" if you prefer
20
+ model.eval()
21
+
22
+ ########################
23
+ # 2) Define Label Mappings
24
+ ########################
25
+
26
+ # For example, if your object labels are saved in code:
27
+ idx_to_obj_label = {
28
+ 0: "cat",
29
+ 1: "dog",
30
+ 2: "car",
31
+ # ... fill in all your categories ...
32
+ }
33
+
34
+ bin_label_names = ["AI-Generated", "Real"] # Adjust if 0=AI, 1=Real
35
+
36
+ ########################
37
+ # 3) Define Transforms
38
+ ########################
39
+
40
+ # Match the transforms you used during validation
41
+ val_transforms = transforms.Compose([
42
+ transforms.Resize(256),
43
+ transforms.CenterCrop(224),
44
+ transforms.ToTensor(),
45
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
46
+ std=[0.229, 0.224, 0.225])
47
+ ])
48
+
49
+ ########################
50
+ # 4) Define the Inference Function
51
+ ########################
52
+
53
+ def predict_image(img: Image.Image) -> str:
54
+ """
55
+ Takes a PIL image, applies transforms, passes through the model,
56
+ and returns the combined prediction (object + AI/Real).
57
+ """
58
+ # Convert to RGB just in case
59
+ img = img.convert("RGB")
60
+
61
+ # Apply transforms
62
+ img_t = val_transforms(img)
63
+ # Add batch dimension
64
+ img_t = img_t.unsqueeze(0)
65
+
66
+ with torch.no_grad():
67
+ obj_logits, bin_logits = model(img_t)
68
+ obj_pred = torch.argmax(obj_logits, dim=1).item()
69
+ bin_pred = torch.argmax(bin_logits, dim=1).item()
70
+
71
+ # Map predictions to labels
72
+ obj_name = idx_to_obj_label.get(obj_pred, "Unknown")
73
+ bin_name = bin_label_names[bin_pred]
74
+
75
+ return f"Object: {obj_name} | Authenticity: {bin_name}"
76
+
77
+ ########################
78
+ # 5) Build Gradio UI
79
+ ########################
80
+
81
+ demo = gr.Interface(
82
+ fn=predict_image,
83
+ inputs=gr.Image(type="pil"),
84
+ outputs="text",
85
+ title="Multi-Task Image Classifier",
86
+ description=(
87
+ "Upload an image to get two predictions: "
88
+ "1) The primary object (from pseudo-labeling), "
89
+ "2) Whether the image is AI-generated or real."
90
+ )
91
+ )
92
+
93
+ ########################
94
+ # 6) Launch the App
95
+ ########################
96
+
97
+ def main():
98
+ demo.launch(server_name="0.0.0.0", enable_queue=True)
99
+
100
+ if __name__ == "__main__":
101
+ main()