Spaces:
Sleeping
Sleeping
import torch | |
from ultralytics import YOLO | |
from fastapi import FastAPI, UploadFile, File | |
from PIL import Image | |
import io | |
import os | |
import numpy as np | |
import cv2 | |
# 為了解決可能的 WeightsUnpickler error,雖然在最新日誌中可能不再需要,但保留以防萬一 | |
try: | |
# 僅在支援的情況下添加這個安全白名單 | |
torch.serialization.add_safe_global("ultralytics.nn.tasks.DetectionModel", True) | |
except AttributeError: | |
# 如果 PyTorch 版本不支持,則跳過 | |
print("PyTorch version does not support _WEIGHTS_ONLY_SAFE_GLOBAL_ALLOWLIST. Skipping add_safe_global.") | |
# 初始化 FastAPI 應用 | |
app = FastAPI() | |
# 載入 YOLOv8 模型 | |
# 確保 best.pt 檔案在相同的目錄中 | |
try: | |
model = YOLO("best.pt") | |
print("YOLOv8 模型在 Hugging Face Space 載入成功!") | |
except Exception as e: | |
raise RuntimeError(f"無法載入 YOLOv8 模型於 Space: {e}") | |
async def read_root(): | |
""" | |
根路徑,用於檢查 API 是否正常運行。 | |
""" | |
return {"message": "Hugging Face Space API for Taiwan Black Bear Detection is running!"} | |
async def predict_image(file: UploadFile = File(...)): | |
""" | |
接收圖片檔案,進行台灣黑熊偵測,並返回帶有邊界框的圖片。 | |
""" | |
try: | |
# 讀取上傳的圖片 | |
contents = await file.read() | |
image = Image.open(io.BytesIO(contents)).convert("RGB") | |
image_np = np.array(image) | |
# 確保顏色通道正確 (PIL 是 RGB, OpenCV 是 BGR) | |
image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR) | |
# 執行推論 | |
# model.predict 默認返回 Results 對象列表 | |
results = model.predict(source=image_cv2, conf=0.25) # 你可以調整 conf 閾值 | |
# 處理結果並繪製邊界框 | |
output_image_np = image_cv2.copy() | |
for r in results: | |
# r.boxes 包含所有偵測到的物體 | |
for box in r.boxes: | |
x1, y1, x2, y2 = map(int, box.xyxy[0]) # 獲取邊界框坐標 | |
conf = round(float(box.conf[0]), 2) # 置信度 | |
cls = int(box.cls[0]) # 類別 ID | |
name = model.names[cls] # 類別名稱 | |
# 只繪製標籤為 'Taiwan-Black-Bear' 的偵測框 | |
if name == 'Taiwan-Black-Bear': | |
# 繪製邊界框 (綠色,粗度 2) | |
cv2.rectangle(output_image_np, (x1, y1), (x2, y2), (0, 255, 0), 2) | |
# 繪製標籤和置信度 | |
label = f'{name} {conf}' | |
cv2.putText(output_image_np, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2) | |
# 將處理後的圖片轉換回 PIL 格式以便返回 | |
output_image_pil = Image.fromarray(cv2.cvtColor(output_image_np, cv2.COLOR_BGR2RGB)) | |
# 將圖片轉換為 BytesIO 對象,以便作為響應發送 | |
byte_arr = io.BytesIO() | |
output_image_pil.save(byte_arr, format='JPEG') | |
byte_arr.seek(0) # 將指針移到開頭 | |
return StreamingResponse(byte_arr, media_type="image/jpeg") | |
except Exception as e: | |
return {"error": str(e)}, 500 | |
# 引入 StreamingResponse | |
from fastapi.responses import StreamingResponse |