File size: 3,265 Bytes
fd6a1e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
from ultralytics import YOLO
from fastapi import FastAPI, UploadFile, File
from PIL import Image
import io
import os
import numpy as np
import cv2

# 為了解決可能的 WeightsUnpickler error,雖然在最新日誌中可能不再需要,但保留以防萬一
try:
    # 僅在支援的情況下添加這個安全白名單
    torch.serialization.add_safe_global("ultralytics.nn.tasks.DetectionModel", True)
except AttributeError:
    # 如果 PyTorch 版本不支持,則跳過
    print("PyTorch version does not support _WEIGHTS_ONLY_SAFE_GLOBAL_ALLOWLIST. Skipping add_safe_global.")

# 初始化 FastAPI 應用
app = FastAPI()

# 載入 YOLOv8 模型
# 確保 best.pt 檔案在相同的目錄中
try:
    model = YOLO("best.pt")
    print("YOLOv8 模型在 Hugging Face Space 載入成功!")
except Exception as e:
    raise RuntimeError(f"無法載入 YOLOv8 模型於 Space: {e}")

@app.get("/")
async def read_root():
    """
    根路徑,用於檢查 API 是否正常運行。
    """
    return {"message": "Hugging Face Space API for Taiwan Black Bear Detection is running!"}

@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """
    接收圖片檔案,進行台灣黑熊偵測,並返回帶有邊界框的圖片。
    """
    try:
        # 讀取上傳的圖片
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
        image_np = np.array(image)

        # 確保顏色通道正確 (PIL 是 RGB, OpenCV 是 BGR)
        image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

        # 執行推論
        # model.predict 默認返回 Results 對象列表
        results = model.predict(source=image_cv2, conf=0.25) # 你可以調整 conf 閾值

        # 處理結果並繪製邊界框
        output_image_np = image_cv2.copy()
        for r in results:
            # r.boxes 包含所有偵測到的物體
            for box in r.boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0]) # 獲取邊界框坐標
                conf = round(float(box.conf[0]), 2) # 置信度
                cls = int(box.cls[0]) # 類別 ID
                name = model.names[cls] # 類別名稱

                # 只繪製標籤為 'Taiwan-Black-Bear' 的偵測框
                if name == 'Taiwan-Black-Bear':
                    # 繪製邊界框 (綠色,粗度 2)
                    cv2.rectangle(output_image_np, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    # 繪製標籤和置信度
                    label = f'{name} {conf}'
                    cv2.putText(output_image_np, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        # 將處理後的圖片轉換回 PIL 格式以便返回
        output_image_pil = Image.fromarray(cv2.cvtColor(output_image_np, cv2.COLOR_BGR2RGB))

        # 將圖片轉換為 BytesIO 對象,以便作為響應發送
        byte_arr = io.BytesIO()
        output_image_pil.save(byte_arr, format='JPEG')
        byte_arr.seek(0) # 將指針移到開頭

        return StreamingResponse(byte_arr, media_type="image/jpeg")

    except Exception as e:
        return {"error": str(e)}, 500

# 引入 StreamingResponse
from fastapi.responses import StreamingResponse