oxkitsune commited on
Commit
1d91ad9
·
1 Parent(s): 976cc78

unique color per class

Browse files
Files changed (1) hide show
  1. app.py +10 -1
app.py CHANGED
@@ -82,7 +82,8 @@ def streaming_object_detection(recording_id: str, img):
82
  # convert outputs (bounding boxes and class logits) to COCO API
83
  # let's only keep detections with score > 0.9
84
 
85
- target_sizes = torch.tensor([img.size[::-1]])
 
86
  results = processor.post_process_object_detection(
87
  outputs, target_sizes=target_sizes, threshold=0.9
88
  )[0]
@@ -94,6 +95,14 @@ def streaming_object_detection(recording_id: str, img):
94
  array=results["boxes"],
95
  array_format=rr.Box2DFormat.XYXY,
96
  labels=[model.config.id2label[label.item()] for label in results["labels"]],
 
 
 
 
 
 
 
 
97
  ),
98
  )
99
 
 
82
  # convert outputs (bounding boxes and class logits) to COCO API
83
  # let's only keep detections with score > 0.9
84
 
85
+ height, width = img.shape[:2]
86
+ target_sizes = torch.tensor([[height, width]]) # [height, width] order
87
  results = processor.post_process_object_detection(
88
  outputs, target_sizes=target_sizes, threshold=0.9
89
  )[0]
 
95
  array=results["boxes"],
96
  array_format=rr.Box2DFormat.XYXY,
97
  labels=[model.config.id2label[label.item()] for label in results["labels"]],
98
+ colors=[
99
+ (
100
+ label.item() * 50 % 255,
101
+ (label.item() * 80 + 40) % 255,
102
+ (label.item() * 120 + 100) % 255,
103
+ )
104
+ for label in results["labels"]
105
+ ],
106
  ),
107
  )
108