Abdu07 commited on
Commit
7ffe08a
·
verified ·
1 Parent(s): eba6774

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -32
app.py CHANGED
@@ -24,21 +24,35 @@ class MultiTaskModel(nn.Module):
24
  return obj_logits, bin_logits
25
 
26
  ########################################
27
- # 2. Load the Label Mapping and Set num_obj_classes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  ########################################
29
  # Load the saved mapping from JSON
30
  with open("obj_label_mapping.json", "r") as f:
31
  obj_label_to_idx = json.load(f)
32
- # Use the mapping as-is; do not override it.
33
- num_obj_classes = len(obj_label_to_idx)
34
  # Create the inverse mapping
35
  idx_to_obj_label = {v: k for k, v in obj_label_to_idx.items()}
36
 
37
  bin_label_names = ["AI-Generated", "Real"]
38
 
39
- ########################################
40
- # 3. Define Validation Transforms
41
- ########################################
42
  val_transforms = transforms.Compose([
43
  transforms.Resize(256),
44
  transforms.CenterCrop(224),
@@ -48,27 +62,7 @@ val_transforms = transforms.Compose([
48
  ])
49
 
50
  ########################################
51
- # 4. Reconstruct the Model and Load Weights
52
- ########################################
53
- device = torch.device("cpu")
54
-
55
- resnet = models.resnet50(pretrained=False)
56
- resnet.fc = nn.Identity()
57
- feature_dim = 2048
58
-
59
- # Build the model architecture.
60
- model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
61
- model.to(device)
62
-
63
- repo_id = "Abdu07/multitask-model"
64
- filename = "DualSight.pth" # Ensure this checkpoint is from training with the same num_obj_classes
65
- weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
66
- state_dict = torch.load(weights_path, map_location="cpu")
67
- model.load_state_dict(state_dict)
68
- model.eval()
69
-
70
- ########################################
71
- # 5. Define the Inference Function
72
  ########################################
73
  def predict_image(img: Image.Image) -> str:
74
  img = img.convert("RGB")
@@ -82,17 +76,15 @@ def predict_image(img: Image.Image) -> str:
82
  return f"Prediction: {obj_name} ({bin_name})"
83
 
84
  ########################################
85
- # 6. Create Gradio UI
86
  ########################################
87
  demo = gr.Interface(
88
  fn=predict_image,
89
  inputs=gr.Image(type="pil"),
90
  outputs="text",
91
- title="Multi-Task Image Classifier Trained by [Abdellahi El Moustapha](https://abmstpha.github.io/)",
92
  description="Upload an image to receive two predictions:\n1) The primary object in the image,\n2) Whether the image is AI-generated or Real."
93
  )
94
 
95
-
96
-
97
  if __name__ == "__main__":
98
- demo.launch(server_name="0.0.0.0", share=True)
 
24
  return obj_logits, bin_logits
25
 
26
  ########################################
27
+ # 2. Reconstruct the Model and Load Weights
28
+ ########################################
29
+ num_obj_classes = 494 # Make sure this matches your training
30
+ device = torch.device("cpu")
31
+
32
+ resnet = models.resnet50(pretrained=False)
33
+ resnet.fc = nn.Identity()
34
+ feature_dim = 2048
35
+ model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
36
+ model.to(device)
37
+
38
+ repo_id = "Abdu07/multitask-model"
39
+ filename = "Yolloplusclassproject_weights.pth"
40
+ weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
41
+ state_dict = torch.load(weights_path, map_location="cpu")
42
+ model.load_state_dict(state_dict)
43
+ model.eval()
44
+
45
+ ########################################
46
+ # 3. Load Label Mapping and Define Transforms
47
  ########################################
48
  # Load the saved mapping from JSON
49
  with open("obj_label_mapping.json", "r") as f:
50
  obj_label_to_idx = json.load(f)
 
 
51
  # Create the inverse mapping
52
  idx_to_obj_label = {v: k for k, v in obj_label_to_idx.items()}
53
 
54
  bin_label_names = ["AI-Generated", "Real"]
55
 
 
 
 
56
  val_transforms = transforms.Compose([
57
  transforms.Resize(256),
58
  transforms.CenterCrop(224),
 
62
  ])
63
 
64
  ########################################
65
+ # 4. Define the Inference Function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  ########################################
67
  def predict_image(img: Image.Image) -> str:
68
  img = img.convert("RGB")
 
76
  return f"Prediction: {obj_name} ({bin_name})"
77
 
78
  ########################################
79
+ # 5. Create Gradio UI
80
  ########################################
81
  demo = gr.Interface(
82
  fn=predict_image,
83
  inputs=gr.Image(type="pil"),
84
  outputs="text",
85
+ title="Multi-Task Image Classifier",
86
  description="Upload an image to receive two predictions:\n1) The primary object in the image,\n2) Whether the image is AI-generated or Real."
87
  )
88
 
 
 
89
  if __name__ == "__main__":
90
+ demo.launch(server_name="0.0.0.0", share=True)