reidddd commited on
Commit
e3eebda
·
1 Parent(s): 032d442

update model file and improved accuracy

Browse files
Files changed (1) hide show
  1. app.py +27 -3
app.py CHANGED
@@ -10,6 +10,8 @@ import os
10
  import requests
11
  import gdown
12
  from skimage import io
 
 
13
 
14
  # Initialize Flask app
15
  app = Flask(__name__)
@@ -90,12 +92,34 @@ def upload():
90
  instances = outputs["instances"].to("cpu")
91
  class_names = MetadataCatalog.get(cfg.DATASETS.TEST[0]).thing_classes
92
 
93
- # Initialize total cost
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  total_cost = 0
95
  damage_details = []
96
 
97
- for j in range(len(instances)):
98
- class_id = instances.pred_classes[j].item()
99
  damaged_part = (
100
  class_names[class_id] if class_id < len(class_names) else "unknown"
101
  )
 
10
  import requests
11
  import gdown
12
  from skimage import io
13
+ from torchvision.ops import box_iou
14
+ import torch
15
 
16
  # Initialize Flask app
17
  app = Flask(__name__)
 
92
  instances = outputs["instances"].to("cpu")
93
  class_names = MetadataCatalog.get(cfg.DATASETS.TEST[0]).thing_classes
94
 
95
+ # Extract bounding boxes and class IDs
96
+ boxes = instances.pred_boxes.tensor.numpy()
97
+ class_ids = instances.pred_classes.numpy()
98
+
99
+ # Filter overlapping boxes using IoU
100
+ iou_threshold = 0.5
101
+ keep_indices = []
102
+ merged_boxes = set()
103
+
104
+ for i in range(len(boxes)):
105
+ if i in merged_boxes:
106
+ continue
107
+ keep_indices.append(i)
108
+ for j in range(i + 1, len(boxes)):
109
+ if j in merged_boxes:
110
+ continue
111
+ iou = box_iou(
112
+ torch.tensor(boxes[i]).unsqueeze(0), torch.tensor(boxes[j]).unsqueeze(0)
113
+ ).item()
114
+ if iou > iou_threshold:
115
+ merged_boxes.add(j)
116
+
117
+ # Calculate total cost based on non-overlapping boxes
118
  total_cost = 0
119
  damage_details = []
120
 
121
+ for idx in keep_indices:
122
+ class_id = class_ids[idx]
123
  damaged_part = (
124
  class_names[class_id] if class_id < len(class_names) else "unknown"
125
  )