Abdu07 commited on
Commit
4fc801a
·
verified ·
1 Parent(s): 2dfeda2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -9
app.py CHANGED
@@ -27,8 +27,8 @@ class MultiTaskModel(nn.Module):
27
  ########################################
28
  # 2. Reconstruct the Model and Load Weights
29
  ########################################
30
- # Set the number of object classes (update this to match your training)
31
- num_obj_classes = 139 # Example value; update as needed
32
 
33
  device = torch.device("cpu")
34
 
@@ -43,7 +43,7 @@ model.to(device)
43
 
44
  # Download the state dict from HF Hub.
45
  repo_id = "Abdu07/multitask-model" # Your repo name
46
- filename = "Yolloplusclassproject_weights.pth" # New weight file name
47
  weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
48
 
49
  # Load the state dict and update the model.
@@ -55,12 +55,10 @@ model.eval()
55
  # 3. Define Label Mappings and Transforms
56
  ########################################
57
  # Update these with your actual label mappings.
58
- idx_to_obj_label = {
59
- 0: "cat",
60
- 1: "dog",
61
- 2: "car",
62
- # ... add the rest of your object classes ...
63
- }
64
  bin_label_names = ["AI-Generated", "Real"]
65
 
66
  # Define the validation transforms (must match those used during training)
 
27
  ########################################
28
  # 2. Reconstruct the Model and Load Weights
29
  ########################################
30
+ # IMPORTANT: Use the same number of object classes as in training.
31
+ num_obj_classes = 494 # Updated to match the state dict from training
32
 
33
  device = torch.device("cpu")
34
 
 
43
 
44
  # Download the state dict from HF Hub.
45
  repo_id = "Abdu07/multitask-model" # Your repo name
46
+ filename = "Yolloplusclassproject_weights.pth" # The state dict file you uploaded
47
  weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
48
 
49
  # Load the state dict and update the model.
 
55
  # 3. Define Label Mappings and Transforms
56
  ########################################
57
  # Update these with your actual label mappings.
58
+ # They should reflect the 494 unique pseudo-labels produced during training.
59
+ # For this example, we assume that the mapping is stored somewhere.
60
+ # Here we provide a dummy mapping for illustration. Replace it with your real mapping.
61
+ idx_to_obj_label = {i: f"label_{i}" for i in range(num_obj_classes)}
 
 
62
  bin_label_names = ["AI-Generated", "Real"]
63
 
64
  # Define the validation transforms (must match those used during training)