ladyzoe commited on
Commit
fd6a1e2
·
verified ·
1 Parent(s): b00c94a

create app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ultralytics import YOLO
3
+ from fastapi import FastAPI, UploadFile, File
4
+ from PIL import Image
5
+ import io
6
+ import os
7
+ import numpy as np
8
+ import cv2
9
+
10
+ # 為了解決可能的 WeightsUnpickler error,雖然在最新日誌中可能不再需要,但保留以防萬一
11
+ try:
12
+ # 僅在支援的情況下添加這個安全白名單
13
+ torch.serialization.add_safe_global("ultralytics.nn.tasks.DetectionModel", True)
14
+ except AttributeError:
15
+ # 如果 PyTorch 版本不支持,則跳過
16
+ print("PyTorch version does not support _WEIGHTS_ONLY_SAFE_GLOBAL_ALLOWLIST. Skipping add_safe_global.")
17
+
18
+ # 初始化 FastAPI 應用
19
+ app = FastAPI()
20
+
21
+ # 載入 YOLOv8 模型
22
+ # 確保 best.pt 檔案在相同的目錄中
23
+ try:
24
+ model = YOLO("best.pt")
25
+ print("YOLOv8 模型在 Hugging Face Space 載入成功!")
26
+ except Exception as e:
27
+ raise RuntimeError(f"無法載入 YOLOv8 模型於 Space: {e}")
28
+
29
+ @app.get("/")
30
+ async def read_root():
31
+ """
32
+ 根路徑,用於檢查 API 是否正常運行。
33
+ """
34
+ return {"message": "Hugging Face Space API for Taiwan Black Bear Detection is running!"}
35
+
36
+ @app.post("/predict")
37
+ async def predict_image(file: UploadFile = File(...)):
38
+ """
39
+ 接收圖片檔案,進行台灣黑熊偵測,並返回帶有邊界框的圖片。
40
+ """
41
+ try:
42
+ # 讀取上傳的圖片
43
+ contents = await file.read()
44
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
45
+ image_np = np.array(image)
46
+
47
+ # 確保顏色通道正確 (PIL 是 RGB, OpenCV 是 BGR)
48
+ image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
49
+
50
+ # 執行推論
51
+ # model.predict 默認返回 Results 對象列表
52
+ results = model.predict(source=image_cv2, conf=0.25) # 你可以調整 conf 閾值
53
+
54
+ # 處理結果並繪製邊界框
55
+ output_image_np = image_cv2.copy()
56
+ for r in results:
57
+ # r.boxes 包含所有偵測到的物體
58
+ for box in r.boxes:
59
+ x1, y1, x2, y2 = map(int, box.xyxy[0]) # 獲取邊界框坐標
60
+ conf = round(float(box.conf[0]), 2) # 置信度
61
+ cls = int(box.cls[0]) # 類別 ID
62
+ name = model.names[cls] # 類別名稱
63
+
64
+ # 只繪製標籤為 'Taiwan-Black-Bear' 的偵測框
65
+ if name == 'Taiwan-Black-Bear':
66
+ # 繪製邊界框 (綠色,粗度 2)
67
+ cv2.rectangle(output_image_np, (x1, y1), (x2, y2), (0, 255, 0), 2)
68
+ # 繪製標籤和置信度
69
+ label = f'{name} {conf}'
70
+ cv2.putText(output_image_np, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
71
+
72
+ # 將處理後的圖片轉換回 PIL 格式以便返回
73
+ output_image_pil = Image.fromarray(cv2.cvtColor(output_image_np, cv2.COLOR_BGR2RGB))
74
+
75
+ # 將圖片轉換為 BytesIO 對象,以便作為響應發送
76
+ byte_arr = io.BytesIO()
77
+ output_image_pil.save(byte_arr, format='JPEG')
78
+ byte_arr.seek(0) # 將指針移到開頭
79
+
80
+ return StreamingResponse(byte_arr, media_type="image/jpeg")
81
+
82
+ except Exception as e:
83
+ return {"error": str(e)}, 500
84
+
85
+ # 引入 StreamingResponse
86
+ from fastapi.responses import StreamingResponse