Abdu07 committed on
Commit b5a7c3b · verified · 1 Parent(s): 220eb39

Update app.py

Files changed (1)
  1. app.py +60 -49
app.py CHANGED
@@ -1,96 +1,107 @@
  import gradio as gr
  import torch
  import torch.nn as nn
- import torchvision.transforms as T
  from PIL import Image
  from huggingface_hub import hf_hub_download

- #####################################
- # 1) Define the same custom class
- #####################################
  class MultiTaskModel(nn.Module):
      def __init__(self, backbone, feature_dim, num_obj_classes):
          super(MultiTaskModel, self).__init__()
          self.backbone = backbone
          self.obj_head = nn.Linear(feature_dim, num_obj_classes)
          self.bin_head = nn.Linear(feature_dim, 2)
-
      def forward(self, x):
          feats = self.backbone(x)
          obj_logits = self.obj_head(feats)
          bin_logits = self.bin_head(feats)
          return obj_logits, bin_logits

- #####################################
- # 2) Allowlist the class
- #####################################
- import torch.serialization
- torch.serialization.add_safe_globals([MultiTaskModel])

- #####################################
- # 3) Download & Load the full model
- #####################################
- repo_id = "Abdu07/multitask-model"  # or your actual repo
- filename = "multitask_model.pth"  # the file you uploaded
- model_path = hf_hub_download(repo_id=repo_id, filename=filename)

- # Force PyTorch to load the full model object
- model = torch.load(model_path, map_location="cpu")  # default weights_only=True, but we added safe_globals
  model.eval()

- #####################################
- # 4) Label Mappings
- #####################################
  idx_to_obj_label = {
-     # Fill in with your actual object label indices
      0: "cat",
      1: "dog",
      2: "car",
-     # ...
  }
  bin_label_names = ["AI-Generated", "Real"]

- #####################################
- # 5) Validation Transforms
- #####################################
- val_transforms = T.Compose([
-     T.Resize(256),
-     T.CenterCrop(224),
-     T.ToTensor(),
-     T.Normalize(mean=[0.485, 0.456, 0.406],
-                 std=[0.229, 0.224, 0.225])
  ])

- #####################################
- # 6) Inference Function
- #####################################
  def predict_image(img: Image.Image) -> str:
      img = img.convert("RGB")
-     img_t = val_transforms(img).unsqueeze(0)
      with torch.no_grad():
-         obj_logits, bin_logits = model(img_t)
-     obj_pred = torch.argmax(obj_logits, dim=1).item()
-     bin_pred = torch.argmax(bin_logits, dim=1).item()
      obj_name = idx_to_obj_label.get(obj_pred, "Unknown")
      bin_name = bin_label_names[bin_pred]
-     return f"Object: {obj_name} | Authenticity: {bin_name}"

- #####################################
- # 7) Gradio UI
- #####################################
  demo = gr.Interface(
      fn=predict_image,
      inputs=gr.Image(type="pil"),
      outputs="text",
      title="Multi-Task Image Classifier",
      description=(
-         "Upload an image to get two predictions: "
-         "1) The primary object, 2) Whether the image is AI-generated or real."
      )
  )

- def main():
-     demo.launch(server_name="0.0.0.0", enable_queue=True)
-
  if __name__ == "__main__":
-     main()
 
  import gradio as gr
  import torch
  import torch.nn as nn
+ import torchvision.models as models
+ import torchvision.transforms as transforms
  from PIL import Image
  from huggingface_hub import hf_hub_download

+ ########################################
+ # 1. Define Model Architecture
+ ########################################
  class MultiTaskModel(nn.Module):
      def __init__(self, backbone, feature_dim, num_obj_classes):
          super(MultiTaskModel, self).__init__()
          self.backbone = backbone
+         # Object recognition head
          self.obj_head = nn.Linear(feature_dim, num_obj_classes)
+         # Binary classification head (0: AI-generated, 1: Real)
          self.bin_head = nn.Linear(feature_dim, 2)
+
      def forward(self, x):
          feats = self.backbone(x)
          obj_logits = self.obj_head(feats)
          bin_logits = self.bin_head(feats)
          return obj_logits, bin_logits

+ ########################################
+ # 2. Reconstruct the Model and Load Weights
+ ########################################
+ # Set the number of object classes (update this to match your training)
+ num_obj_classes = 139  # Example; change as needed
+ device = torch.device("cpu")

+ # Instantiate the backbone (ResNet-50 without its final layer)
+ resnet = models.resnet50(pretrained=False)
+ resnet.fc = nn.Identity()
+ feature_dim = 2048
+
+ # Build the model architecture
+ model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
+ model.to(device)

+ # Download the state dict from HF Hub
+ repo_id = "Abdu07/multitask-model"
+ filename = "best_model.pt"  # Make sure this is the state dict file
+ model_path = hf_hub_download(repo_id=repo_id, filename=filename)
+ state_dict = torch.load(model_path, map_location="cpu")
+ model.load_state_dict(state_dict)
  model.eval()

+ ########################################
+ # 3. Define Label Mappings and Transforms
+ ########################################
+ # Update these mappings with your actual training labels.
  idx_to_obj_label = {
      0: "cat",
      1: "dog",
      2: "car",
+     # ... add your object classes here ...
  }
  bin_label_names = ["AI-Generated", "Real"]

+ # Define the validation transforms (same as used during training/validation)
+ val_transforms = transforms.Compose([
+     transforms.Resize(256),
+     transforms.CenterCrop(224),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                          std=[0.229, 0.224, 0.225])
  ])

+ ########################################
+ # 4. Define the Inference Function
+ ########################################
  def predict_image(img: Image.Image) -> str:
+     """
+     Takes an uploaded PIL image, processes it, and returns the model's prediction.
+     """
+     # Ensure image is in RGB
      img = img.convert("RGB")
+     # Apply validation transforms
+     img_tensor = val_transforms(img).unsqueeze(0).to(device)  # Shape: [1, 3, 224, 224]
      with torch.no_grad():
+         obj_logits, bin_logits = model(img_tensor)
+     obj_pred = torch.argmax(obj_logits, dim=1).item()
+     bin_pred = torch.argmax(bin_logits, dim=1).item()
      obj_name = idx_to_obj_label.get(obj_pred, "Unknown")
      bin_name = bin_label_names[bin_pred]
+     return f"Prediction: {obj_name} ({bin_name})"

+ ########################################
+ # 5. Create Gradio UI
+ ########################################
  demo = gr.Interface(
      fn=predict_image,
      inputs=gr.Image(type="pil"),
      outputs="text",
      title="Multi-Task Image Classifier",
      description=(
+         "Upload an image to receive two predictions:\n"
+         "1) The primary object in the image,\n"
+         "2) Whether the image is AI-generated or Real."
      )
  )

  if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", share=True)
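
A minimal standalone sanity check of the new state-dict loading path might look like the sketch below. It is not part of app.py: it reuses the repo_id, filename, and num_obj_classes=139 from the updated script, uses weights=None (the newer torchvision spelling of pretrained=False), and passes weights_only=True, which is available in recent PyTorch and works as long as best_model.pt is a plain tensor state dict.

# Hypothetical sanity check; assumes the checkpoint matches the architecture above.
import torch
import torch.nn as nn
import torchvision.models as models
from huggingface_hub import hf_hub_download

class MultiTaskModel(nn.Module):  # same definition as in app.py
    def __init__(self, backbone, feature_dim, num_obj_classes):
        super().__init__()
        self.backbone = backbone
        self.obj_head = nn.Linear(feature_dim, num_obj_classes)
        self.bin_head = nn.Linear(feature_dim, 2)

    def forward(self, x):
        feats = self.backbone(x)
        return self.obj_head(feats), self.bin_head(feats)

resnet = models.resnet50(weights=None)  # newer equivalent of pretrained=False
resnet.fc = nn.Identity()
model = MultiTaskModel(resnet, feature_dim=2048, num_obj_classes=139)

path = hf_hub_download(repo_id="Abdu07/multitask-model", filename="best_model.pt")
state_dict = torch.load(path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()

with torch.no_grad():
    obj_logits, bin_logits = model(torch.zeros(1, 3, 224, 224))
print(obj_logits.shape, bin_logits.shape)  # expect torch.Size([1, 139]) and torch.Size([1, 2])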