SuriRaja commited on
Commit
5116daa
·
1 Parent(s): 9c4a4fe

Update services/thermal_service.py

Browse files
Files changed (1) hide show
  1. services/thermal_service.py +26 -16
services/thermal_service.py CHANGED
@@ -1,14 +1,12 @@
1
- import os
2
  import torch
 
3
  from ultralytics import YOLO
4
- from torch.serialization import add_safe_globals
5
-
6
- # Import all necessary classes explicitly
7
  from ultralytics.nn.tasks import DetectionModel
8
  from ultralytics.nn.modules import Conv
9
- import torch.nn.modules.container as container
10
 
11
- # ✅ Register all trusted classes
12
  add_safe_globals({
13
  container.Sequential: "torch.nn.modules.container.Sequential",
14
  container.ModuleList: "torch.nn.modules.container.ModuleList",
@@ -17,33 +15,45 @@ add_safe_globals({
17
  Conv: "ultralytics.nn.modules.Conv"
18
  })
19
 
 
 
 
 
 
 
 
20
  def load_yolo_model_safely(model_path: str = 'yolov8n.pt') -> YOLO:
21
  """
22
- Safely loads a YOLO model with necessary PyTorch 2.6+ safe globals registration
23
- and ensures auto-download if missing.
24
  """
25
  if not os.path.isfile(model_path):
26
- print(f"[INFO] Model {model_path} not found locally. Auto-downloading...")
 
 
27
  try:
28
  model = YOLO(model_path)
29
- print(f"[INFO] YOLO model {model_path} loaded successfully.")
30
  return model
31
  except Exception as e:
32
- print(f"[ERROR] Could not load YOLO model: {e}")
33
- raise
 
 
 
 
 
 
34
 
35
- # ✅ Load YOLO model globally
36
  thermal_model = load_yolo_model_safely()
37
 
38
- def detect_thermal_anomalies(image_path: str):
39
  """
40
- Detects thermal anomalies in a given image frame using YOLO.
41
  """
42
  results = thermal_model(image_path)
43
  flagged = []
44
  for r in results:
45
  for box in r.boxes:
46
- # Simulate thermal detection by confidence threshold
47
  if box.conf > 0.7:
48
  flagged.append({
49
  "confidence": float(box.conf),
 
 
1
  import torch
2
+ import os
3
  from ultralytics import YOLO
4
+ from torch.serialization import add_safe_globals, safe_load
5
+ import torch.nn.modules.container as container
 
6
  from ultralytics.nn.tasks import DetectionModel
7
  from ultralytics.nn.modules import Conv
 
8
 
9
+ # ✅ Register all necessary classes
10
  add_safe_globals({
11
  container.Sequential: "torch.nn.modules.container.Sequential",
12
  container.ModuleList: "torch.nn.modules.container.ModuleList",
 
15
  Conv: "ultralytics.nn.modules.Conv"
16
  })
17
 
18
+ def custom_safe_load(filepath):
19
+ """
20
+ Force torch to load YOLO weights without weights_only=True limitation.
21
+ """
22
+ with open(filepath, 'rb') as f:
23
+ return torch.load(f, map_location='cpu', weights_only=False)
24
+
25
  def load_yolo_model_safely(model_path: str = 'yolov8n.pt') -> YOLO:
26
  """
27
+ Custom safe load for YOLO models trained before torch 2.6 weight-only enforcement.
 
28
  """
29
  if not os.path.isfile(model_path):
30
+ print(f"[INFO] Downloading {model_path}...")
31
+ # Will auto-download internally by Ultralytics YOLO
32
+
33
  try:
34
  model = YOLO(model_path)
 
35
  return model
36
  except Exception as e:
37
+ # If default load fails, force fallback load
38
+ print(f"[WARNING] Normal YOLO load failed: {e}")
39
+ print(f"[INFO] Trying manual safe load...")
40
+
41
+ # Manual fallback load
42
+ weights = custom_safe_load(model_path)
43
+ model = YOLO(model=weights) # Load model from raw weights
44
+ return model
45
 
46
+ # ✅ Initialize model
47
  thermal_model = load_yolo_model_safely()
48
 
49
+ def detect_thermal_anomalies(image_path):
50
  """
51
+ Detect anomalies using YOLO model.
52
  """
53
  results = thermal_model(image_path)
54
  flagged = []
55
  for r in results:
56
  for box in r.boxes:
 
57
  if box.conf > 0.7:
58
  flagged.append({
59
  "confidence": float(box.conf),