Abdu07 commited on
Commit
4375fb7
·
verified ·
1 Parent(s): 5963401

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -5
app.py CHANGED
@@ -26,7 +26,10 @@ class MultiTaskModel(nn.Module):
26
  ########################################
27
  # 2. Reconstruct the Model and Load Weights
28
  ########################################
29
- num_obj_classes = 494 # Make sure this matches your training
 
 
 
30
  device = torch.device("cpu")
31
 
32
  resnet = models.resnet50(pretrained=False)
@@ -45,10 +48,17 @@ model.eval()
45
  ########################################
46
  # 3. Load Label Mapping and Define Transforms
47
  ########################################
48
- # Load the saved mapping from JSON
49
- with open("obj_label_mapping.json", "r") as f:
50
- obj_label_to_idx = json.load(f)
51
- # Create the inverse mapping
 
 
 
 
 
 
 
52
  idx_to_obj_label = {v: k for k, v in obj_label_to_idx.items()}
53
 
54
  bin_label_names = ["AI-Generated", "Real"]
 
26
  ########################################
27
  # 2. Reconstruct the Model and Load Weights
28
  ########################################
29
+ # IMPORTANT: The checkpoint was saved with a single object class,
30
+ # so we set num_obj_classes to 1.
31
+ num_obj_classes = 1
32
+
33
  device = torch.device("cpu")
34
 
35
  resnet = models.resnet50(pretrained=False)
 
48
  ########################################
49
  # 3. Load Label Mapping and Define Transforms
50
  ########################################
51
+ # Attempt to load the mapping from JSON.
52
+ # If the mapping contains more than one label, we override it with a single-label mapping
53
+ try:
54
+ with open("obj_label_mapping.json", "r") as f:
55
+ obj_label_to_idx = json.load(f)
56
+ if len(obj_label_to_idx) != 1:
57
+ obj_label_to_idx = {"Detected Object": 0}
58
+ except Exception as e:
59
+ print("Error loading mapping, using default mapping. Error:", e)
60
+ obj_label_to_idx = {"Detected Object": 0}
61
+
62
  idx_to_obj_label = {v: k for k, v in obj_label_to_idx.items()}
63
 
64
  bin_label_names = ["AI-Generated", "Real"]