Abdu07 commited on
Commit
526455e
·
verified ·
1 Parent(s): 4fc801a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -32
app.py CHANGED
@@ -5,6 +5,7 @@ import torchvision.models as models
5
  import torchvision.transforms as transforms
6
  from PIL import Image
7
  from huggingface_hub import hf_hub_download
 
8
 
9
  ########################################
10
  # 1. Define the Model Architecture
@@ -13,9 +14,7 @@ class MultiTaskModel(nn.Module):
13
  def __init__(self, backbone, feature_dim, num_obj_classes):
14
  super(MultiTaskModel, self).__init__()
15
  self.backbone = backbone
16
- # Object recognition head
17
  self.obj_head = nn.Linear(feature_dim, num_obj_classes)
18
- # Binary classification head (0: AI-generated, 1: Real)
19
  self.bin_head = nn.Linear(feature_dim, 2)
20
 
21
  def forward(self, x):
@@ -27,41 +26,33 @@ class MultiTaskModel(nn.Module):
27
  ########################################
28
  # 2. Reconstruct the Model and Load Weights
29
  ########################################
30
- # IMPORTANT: Use the same number of object classes as in training.
31
- num_obj_classes = 494 # Updated to match the state dict from training
32
-
33
  device = torch.device("cpu")
34
 
35
- # Instantiate the backbone: a ResNet-50 with its final layer removed.
36
  resnet = models.resnet50(pretrained=False)
37
- resnet.fc = nn.Identity() # Remove final classification layer
38
  feature_dim = 2048
39
-
40
- # Build the model architecture.
41
  model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
42
  model.to(device)
43
 
44
- # Download the state dict from HF Hub.
45
- repo_id = "Abdu07/multitask-model" # Your repo name
46
- filename = "Yolloplusclassproject_weights.pth" # The state dict file you uploaded
47
  weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
48
-
49
- # Load the state dict and update the model.
50
  state_dict = torch.load(weights_path, map_location="cpu")
51
  model.load_state_dict(state_dict)
52
  model.eval()
53
 
54
  ########################################
55
- # 3. Define Label Mappings and Transforms
56
  ########################################
57
- # Update these with your actual label mappings.
58
- # They should reflect the 494 unique pseudo-labels produced during training.
59
- # For this example, we assume that the mapping is stored somewhere.
60
- # Here we provide a dummy mapping for illustration. Replace it with your real mapping.
61
- idx_to_obj_label = {i: f"label_{i}" for i in range(num_obj_classes)}
 
62
  bin_label_names = ["AI-Generated", "Real"]
63
 
64
- # Define the validation transforms (must match those used during training)
65
  val_transforms = transforms.Compose([
66
  transforms.Resize(256),
67
  transforms.CenterCrop(224),
@@ -74,13 +65,8 @@ val_transforms = transforms.Compose([
74
  # 4. Define the Inference Function
75
  ########################################
76
  def predict_image(img: Image.Image) -> str:
77
- """
78
- Takes an uploaded PIL image, processes it, and returns the model's prediction.
79
- """
80
- # Ensure the image is in RGB mode.
81
  img = img.convert("RGB")
82
- # Apply validation transforms.
83
- img_tensor = val_transforms(img).unsqueeze(0).to(device) # Shape: [1, 3, 224, 224]
84
  with torch.no_grad():
85
  obj_logits, bin_logits = model(img_tensor)
86
  obj_pred = torch.argmax(obj_logits, dim=1).item()
@@ -97,11 +83,7 @@ demo = gr.Interface(
97
  inputs=gr.Image(type="pil"),
98
  outputs="text",
99
  title="Multi-Task Image Classifier",
100
- description=(
101
- "Upload an image to receive two predictions:\n"
102
- "1) The primary object in the image,\n"
103
- "2) Whether the image is AI-generated or Real."
104
- )
105
  )
106
 
107
  if __name__ == "__main__":
 
5
  import torchvision.transforms as transforms
6
  from PIL import Image
7
  from huggingface_hub import hf_hub_download
8
+ import json
9
 
10
  ########################################
11
  # 1. Define the Model Architecture
 
14
  def __init__(self, backbone, feature_dim, num_obj_classes):
15
  super(MultiTaskModel, self).__init__()
16
  self.backbone = backbone
 
17
  self.obj_head = nn.Linear(feature_dim, num_obj_classes)
 
18
  self.bin_head = nn.Linear(feature_dim, 2)
19
 
20
  def forward(self, x):
 
26
  ########################################
27
  # 2. Reconstruct the Model and Load Weights
28
  ########################################
29
+ num_obj_classes = 494 # Make sure this matches your training
 
 
30
  device = torch.device("cpu")
31
 
 
32
  resnet = models.resnet50(pretrained=False)
33
+ resnet.fc = nn.Identity()
34
  feature_dim = 2048
 
 
35
  model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
36
  model.to(device)
37
 
38
+ repo_id = "Abdu07/multitask-model"
39
+ filename = "Yolloplusclassproject_weights.pth"
 
40
  weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
 
 
41
  state_dict = torch.load(weights_path, map_location="cpu")
42
  model.load_state_dict(state_dict)
43
  model.eval()
44
 
45
  ########################################
46
+ # 3. Load Label Mapping and Define Transforms
47
  ########################################
48
+ # Load the saved mapping from JSON
49
+ with open("obj_label_mapping.json", "r") as f:
50
+ obj_label_to_idx = json.load(f)
51
+ # Create the inverse mapping
52
+ idx_to_obj_label = {v: k for k, v in obj_label_to_idx.items()}
53
+
54
  bin_label_names = ["AI-Generated", "Real"]
55
 
 
56
  val_transforms = transforms.Compose([
57
  transforms.Resize(256),
58
  transforms.CenterCrop(224),
 
65
  # 4. Define the Inference Function
66
  ########################################
67
  def predict_image(img: Image.Image) -> str:
 
 
 
 
68
  img = img.convert("RGB")
69
+ img_tensor = val_transforms(img).unsqueeze(0).to(device)
 
70
  with torch.no_grad():
71
  obj_logits, bin_logits = model(img_tensor)
72
  obj_pred = torch.argmax(obj_logits, dim=1).item()
 
83
  inputs=gr.Image(type="pil"),
84
  outputs="text",
85
  title="Multi-Task Image Classifier",
86
+ description="Upload an image to receive two predictions:\n1) The primary object in the image,\n2) Whether the image is AI-generated or Real."
 
 
 
 
87
  )
88
 
89
  if __name__ == "__main__":