Spaces:
Sleeping
Sleeping
File size: 3,265 Bytes
fd6a1e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import torch
from ultralytics import YOLO
from fastapi import FastAPI, UploadFile, File
from PIL import Image
import io
import os
import numpy as np
import cv2
# 為了解決可能的 WeightsUnpickler error,雖然在最新日誌中可能不再需要,但保留以防萬一
try:
# 僅在支援的情況下添加這個安全白名單
torch.serialization.add_safe_global("ultralytics.nn.tasks.DetectionModel", True)
except AttributeError:
# 如果 PyTorch 版本不支持,則跳過
print("PyTorch version does not support _WEIGHTS_ONLY_SAFE_GLOBAL_ALLOWLIST. Skipping add_safe_global.")
# 初始化 FastAPI 應用
app = FastAPI()
# 載入 YOLOv8 模型
# 確保 best.pt 檔案在相同的目錄中
try:
model = YOLO("best.pt")
print("YOLOv8 模型在 Hugging Face Space 載入成功!")
except Exception as e:
raise RuntimeError(f"無法載入 YOLOv8 模型於 Space: {e}")
@app.get("/")
async def read_root():
"""
根路徑,用於檢查 API 是否正常運行。
"""
return {"message": "Hugging Face Space API for Taiwan Black Bear Detection is running!"}
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
"""
接收圖片檔案,進行台灣黑熊偵測,並返回帶有邊界框的圖片。
"""
try:
# 讀取上傳的圖片
contents = await file.read()
image = Image.open(io.BytesIO(contents)).convert("RGB")
image_np = np.array(image)
# 確保顏色通道正確 (PIL 是 RGB, OpenCV 是 BGR)
image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
# 執行推論
# model.predict 默認返回 Results 對象列表
results = model.predict(source=image_cv2, conf=0.25) # 你可以調整 conf 閾值
# 處理結果並繪製邊界框
output_image_np = image_cv2.copy()
for r in results:
# r.boxes 包含所有偵測到的物體
for box in r.boxes:
x1, y1, x2, y2 = map(int, box.xyxy[0]) # 獲取邊界框坐標
conf = round(float(box.conf[0]), 2) # 置信度
cls = int(box.cls[0]) # 類別 ID
name = model.names[cls] # 類別名稱
# 只繪製標籤為 'Taiwan-Black-Bear' 的偵測框
if name == 'Taiwan-Black-Bear':
# 繪製邊界框 (綠色,粗度 2)
cv2.rectangle(output_image_np, (x1, y1), (x2, y2), (0, 255, 0), 2)
# 繪製標籤和置信度
label = f'{name} {conf}'
cv2.putText(output_image_np, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
# 將處理後的圖片轉換回 PIL 格式以便返回
output_image_pil = Image.fromarray(cv2.cvtColor(output_image_np, cv2.COLOR_BGR2RGB))
# 將圖片轉換為 BytesIO 對象,以便作為響應發送
byte_arr = io.BytesIO()
output_image_pil.save(byte_arr, format='JPEG')
byte_arr.seek(0) # 將指針移到開頭
return StreamingResponse(byte_arr, media_type="image/jpeg")
except Exception as e:
return {"error": str(e)}, 500
# 引入 StreamingResponse
from fastapi.responses import StreamingResponse |