ladyzoe's picture
create app.py
fd6a1e2 verified
raw
history blame
3.27 kB
import torch
from ultralytics import YOLO
from fastapi import FastAPI, UploadFile, File
from PIL import Image
import io
import os
import numpy as np
import cv2
# 為了解決可能的 WeightsUnpickler error,雖然在最新日誌中可能不再需要,但保留以防萬一
try:
# 僅在支援的情況下添加這個安全白名單
torch.serialization.add_safe_global("ultralytics.nn.tasks.DetectionModel", True)
except AttributeError:
# 如果 PyTorch 版本不支持,則跳過
print("PyTorch version does not support _WEIGHTS_ONLY_SAFE_GLOBAL_ALLOWLIST. Skipping add_safe_global.")
# 初始化 FastAPI 應用
app = FastAPI()
# 載入 YOLOv8 模型
# 確保 best.pt 檔案在相同的目錄中
try:
model = YOLO("best.pt")
print("YOLOv8 模型在 Hugging Face Space 載入成功!")
except Exception as e:
raise RuntimeError(f"無法載入 YOLOv8 模型於 Space: {e}")
@app.get("/")
async def read_root():
"""
根路徑,用於檢查 API 是否正常運行。
"""
return {"message": "Hugging Face Space API for Taiwan Black Bear Detection is running!"}
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
"""
接收圖片檔案,進行台灣黑熊偵測,並返回帶有邊界框的圖片。
"""
try:
# 讀取上傳的圖片
contents = await file.read()
image = Image.open(io.BytesIO(contents)).convert("RGB")
image_np = np.array(image)
# 確保顏色通道正確 (PIL 是 RGB, OpenCV 是 BGR)
image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
# 執行推論
# model.predict 默認返回 Results 對象列表
results = model.predict(source=image_cv2, conf=0.25) # 你可以調整 conf 閾值
# 處理結果並繪製邊界框
output_image_np = image_cv2.copy()
for r in results:
# r.boxes 包含所有偵測到的物體
for box in r.boxes:
x1, y1, x2, y2 = map(int, box.xyxy[0]) # 獲取邊界框坐標
conf = round(float(box.conf[0]), 2) # 置信度
cls = int(box.cls[0]) # 類別 ID
name = model.names[cls] # 類別名稱
# 只繪製標籤為 'Taiwan-Black-Bear' 的偵測框
if name == 'Taiwan-Black-Bear':
# 繪製邊界框 (綠色,粗度 2)
cv2.rectangle(output_image_np, (x1, y1), (x2, y2), (0, 255, 0), 2)
# 繪製標籤和置信度
label = f'{name} {conf}'
cv2.putText(output_image_np, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
# 將處理後的圖片轉換回 PIL 格式以便返回
output_image_pil = Image.fromarray(cv2.cvtColor(output_image_np, cv2.COLOR_BGR2RGB))
# 將圖片轉換為 BytesIO 對象,以便作為響應發送
byte_arr = io.BytesIO()
output_image_pil.save(byte_arr, format='JPEG')
byte_arr.seek(0) # 將指針移到開頭
return StreamingResponse(byte_arr, media_type="image/jpeg")
except Exception as e:
return {"error": str(e)}, 500
# 引入 StreamingResponse
from fastapi.responses import StreamingResponse