SuriRaja commited on
Commit
45a917a
·
1 Parent(s): 7766ce0

Update services/thermal_service.py

Browse files
Files changed (1) hide show
  1. services/thermal_service.py +16 -16
services/thermal_service.py CHANGED
@@ -1,23 +1,23 @@
1
- import cv2
 
2
  from PIL import Image
3
- from transformers import pipeline
4
 
5
- thermal_detector = pipeline("object-detection", model="facebook/detr-resnet-50")
 
 
6
 
7
  def detect_thermal_anomalies(frame):
8
- pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
9
- results = thermal_detector(pil_img)
10
- boxes = []
11
 
12
- for result in results:
13
- box = result["box"]
14
- boxes.append((
15
- int(box["xmin"]),
16
- int(box["ymin"]),
17
- int(box["xmax"]),
18
- int(box["ymax"]),
19
- result["score"],
20
- result["label"],
21
- ))
22
 
23
  return boxes
 
1
+ from transformers import DetrImageProcessor, DetrForObjectDetection
2
+ import torch
3
  from PIL import Image
4
+ import cv2
5
 
6
+ # Load model
7
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
8
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
9
 
10
  def detect_thermal_anomalies(frame):
11
+ image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
12
+ inputs = processor(images=image, return_tensors="pt")
13
+ outputs = model(**inputs)
14
 
15
+ target_sizes = torch.tensor([image.size[::-1]])
16
+ results = processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[0]
17
+
18
+ boxes = []
19
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
20
+ if score >= 0.9:
21
+ boxes.append(box.tolist())
 
 
 
22
 
23
  return boxes