SuriRaja commited on
Commit
2827306
·
1 Parent(s): 5116daa

Update services/thermal_service.py

Browse files
Files changed (1) hide show
  1. services/thermal_service.py +10 -27
services/thermal_service.py CHANGED
@@ -1,12 +1,12 @@
1
- import torch
2
  import os
 
3
  from ultralytics import YOLO
4
- from torch.serialization import add_safe_globals, safe_load
5
  import torch.nn.modules.container as container
6
  from ultralytics.nn.tasks import DetectionModel
7
  from ultralytics.nn.modules import Conv
8
 
9
- # ✅ Register all necessary classes
10
  add_safe_globals({
11
  container.Sequential: "torch.nn.modules.container.Sequential",
12
  container.ModuleList: "torch.nn.modules.container.ModuleList",
@@ -15,40 +15,23 @@ add_safe_globals({
15
  Conv: "ultralytics.nn.modules.Conv"
16
  })
17
 
18
- def custom_safe_load(filepath):
19
- """
20
- Force torch to load YOLO weights without weights_only=True limitation.
21
- """
22
- with open(filepath, 'rb') as f:
23
- return torch.load(f, map_location='cpu', weights_only=False)
24
-
25
- def load_yolo_model_safely(model_path: str = 'yolov8n.pt') -> YOLO:
26
  """
27
- Custom safe load for YOLO models trained before torch 2.6 weight-only enforcement.
28
  """
29
- if not os.path.isfile(model_path):
30
- print(f"[INFO] Downloading {model_path}...")
31
- # Will auto-download internally by Ultralytics YOLO
32
-
33
  try:
34
- model = YOLO(model_path)
 
35
  return model
36
  except Exception as e:
37
- # If default load fails, force fallback load
38
- print(f"[WARNING] Normal YOLO load failed: {e}")
39
- print(f"[INFO] Trying manual safe load...")
40
-
41
- # Manual fallback load
42
- weights = custom_safe_load(model_path)
43
- model = YOLO(model=weights) # Load model from raw weights
44
- return model
45
 
46
- # ✅ Initialize model
47
  thermal_model = load_yolo_model_safely()
48
 
49
  def detect_thermal_anomalies(image_path):
50
  """
51
- Detect anomalies using YOLO model.
52
  """
53
  results = thermal_model(image_path)
54
  flagged = []
 
 
1
  import os
2
+ import torch
3
  from ultralytics import YOLO
4
+ from torch.serialization import add_safe_globals
5
  import torch.nn.modules.container as container
6
  from ultralytics.nn.tasks import DetectionModel
7
  from ultralytics.nn.modules import Conv
8
 
9
+ # ✅ Register all necessary safe classes
10
  add_safe_globals({
11
  container.Sequential: "torch.nn.modules.container.Sequential",
12
  container.ModuleList: "torch.nn.modules.container.ModuleList",
 
15
  Conv: "ultralytics.nn.modules.Conv"
16
  })
17
 
18
+ def load_yolo_model_safely():
 
 
 
 
 
 
 
19
  """
20
+ Use direct pretrained YOLOv8n model from Ultralytics Hub (no local weights download needed).
21
  """
 
 
 
 
22
  try:
23
+ model = YOLO('yolov8n.pt') # pretrained small model directly from Ultralytics hub
24
+ print("[INFO] YOLOv8 model loaded successfully.")
25
  return model
26
  except Exception as e:
27
+ print(f"[ERROR] Failed to load YOLO model: {e}")
28
+ raise
 
 
 
 
 
 
29
 
 
30
  thermal_model = load_yolo_model_safely()
31
 
32
  def detect_thermal_anomalies(image_path):
33
  """
34
+ Detect anomalies in an image using the loaded YOLO model.
35
  """
36
  results = thermal_model(image_path)
37
  flagged = []