Spaces:
Sleeping
Sleeping
create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from ultralytics import YOLO
|
| 3 |
+
from fastapi import FastAPI, UploadFile, File
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import io
|
| 6 |
+
import os
|
| 7 |
+
import numpy as np
|
| 8 |
+
import cv2
|
| 9 |
+
|
| 10 |
+
# 為了解決可能的 WeightsUnpickler error,雖然在最新日誌中可能不再需要,但保留以防萬一
|
| 11 |
+
try:
|
| 12 |
+
# 僅在支援的情況下添加這個安全白名單
|
| 13 |
+
torch.serialization.add_safe_global("ultralytics.nn.tasks.DetectionModel", True)
|
| 14 |
+
except AttributeError:
|
| 15 |
+
# 如果 PyTorch 版本不支持,則跳過
|
| 16 |
+
print("PyTorch version does not support _WEIGHTS_ONLY_SAFE_GLOBAL_ALLOWLIST. Skipping add_safe_global.")
|
| 17 |
+
|
| 18 |
+
# 初始化 FastAPI 應用
|
| 19 |
+
app = FastAPI()
|
| 20 |
+
|
| 21 |
+
# 載入 YOLOv8 模型
|
| 22 |
+
# 確保 best.pt 檔案在相同的目錄中
|
| 23 |
+
try:
|
| 24 |
+
model = YOLO("best.pt")
|
| 25 |
+
print("YOLOv8 模型在 Hugging Face Space 載入成功!")
|
| 26 |
+
except Exception as e:
|
| 27 |
+
raise RuntimeError(f"無法載入 YOLOv8 模型於 Space: {e}")
|
| 28 |
+
|
| 29 |
+
@app.get("/")
|
| 30 |
+
async def read_root():
|
| 31 |
+
"""
|
| 32 |
+
根路徑,用於檢查 API 是否正常運行。
|
| 33 |
+
"""
|
| 34 |
+
return {"message": "Hugging Face Space API for Taiwan Black Bear Detection is running!"}
|
| 35 |
+
|
| 36 |
+
@app.post("/predict")
|
| 37 |
+
async def predict_image(file: UploadFile = File(...)):
|
| 38 |
+
"""
|
| 39 |
+
接收圖片檔案,進行台灣黑熊偵測,並返回帶有邊界框的圖片。
|
| 40 |
+
"""
|
| 41 |
+
try:
|
| 42 |
+
# 讀取上傳的圖片
|
| 43 |
+
contents = await file.read()
|
| 44 |
+
image = Image.open(io.BytesIO(contents)).convert("RGB")
|
| 45 |
+
image_np = np.array(image)
|
| 46 |
+
|
| 47 |
+
# 確保顏色通道正確 (PIL 是 RGB, OpenCV 是 BGR)
|
| 48 |
+
image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
|
| 49 |
+
|
| 50 |
+
# 執行推論
|
| 51 |
+
# model.predict 默認返回 Results 對象列表
|
| 52 |
+
results = model.predict(source=image_cv2, conf=0.25) # 你可以調整 conf 閾值
|
| 53 |
+
|
| 54 |
+
# 處理結果並繪製邊界框
|
| 55 |
+
output_image_np = image_cv2.copy()
|
| 56 |
+
for r in results:
|
| 57 |
+
# r.boxes 包含所有偵測到的物體
|
| 58 |
+
for box in r.boxes:
|
| 59 |
+
x1, y1, x2, y2 = map(int, box.xyxy[0]) # 獲取邊界框坐標
|
| 60 |
+
conf = round(float(box.conf[0]), 2) # 置信度
|
| 61 |
+
cls = int(box.cls[0]) # 類別 ID
|
| 62 |
+
name = model.names[cls] # 類別名稱
|
| 63 |
+
|
| 64 |
+
# 只繪製標籤為 'Taiwan-Black-Bear' 的偵測框
|
| 65 |
+
if name == 'Taiwan-Black-Bear':
|
| 66 |
+
# 繪製邊界框 (綠色,粗度 2)
|
| 67 |
+
cv2.rectangle(output_image_np, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
| 68 |
+
# 繪製標籤和置信度
|
| 69 |
+
label = f'{name} {conf}'
|
| 70 |
+
cv2.putText(output_image_np, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
|
| 71 |
+
|
| 72 |
+
# 將處理後的圖片轉換回 PIL 格式以便返回
|
| 73 |
+
output_image_pil = Image.fromarray(cv2.cvtColor(output_image_np, cv2.COLOR_BGR2RGB))
|
| 74 |
+
|
| 75 |
+
# 將圖片轉換為 BytesIO 對象,以便作為響應發送
|
| 76 |
+
byte_arr = io.BytesIO()
|
| 77 |
+
output_image_pil.save(byte_arr, format='JPEG')
|
| 78 |
+
byte_arr.seek(0) # 將指針移到開頭
|
| 79 |
+
|
| 80 |
+
return StreamingResponse(byte_arr, media_type="image/jpeg")
|
| 81 |
+
|
| 82 |
+
except Exception as e:
|
| 83 |
+
return {"error": str(e)}, 500
|
| 84 |
+
|
| 85 |
+
# 引入 StreamingResponse
|
| 86 |
+
from fastapi.responses import StreamingResponse
|