fastapi-yoloe / app.py
from fastapi import FastAPI
from ultralytics import YOLOE
import io
from PIL import Image
import os
from huggingface_hub import hf_hub_download
import requests
import supervision as sv  # used only by the optional annotation code below
# Setup notes: YOLOE and its third-party dependencies are installed from
# source, and the MobileCLIP checkpoint used for text embeddings is fetched
# separately.
#   pip install -q "git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/CLIP"
#   pip install -q "git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/ml-mobileclip"
#   pip install -q "git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/lvis-api"
#   pip install -q "git+https://github.com/THU-MIG/yoloe.git"
#   wget -q https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt

def init_model(model_id: str):
    # Text prompting via get_text_pe()/set_classes() requires the standard
    # "-seg" checkpoint; the prompt-free "-seg-pf" variant ships with a fixed
    # built-in vocabulary instead, so it is not used here.
    is_pf = False
    # Create a models directory if it doesn't exist.
    os.makedirs("models", exist_ok=True)
    filename = f"{model_id}-seg-pf.pt" if is_pf else f"{model_id}-seg.pt"
    # Download the checkpoint into ./models and load it. hf_hub_download
    # returns the local path of the downloaded file.
    local_path = hf_hub_download(
        repo_id="jameslahm/yoloe", filename=filename, local_dir="models"
    )
    model = YOLOE(local_path)
    model.eval()
    return model
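
# A possible optimization (a sketch, not wired in): init_model() reloads
# weights on every request. Wrapping it in an LRU cache would keep recently
# used models in memory; get_model() below is a hypothetical helper.
#
#   from functools import lru_cache
#
#   @lru_cache(maxsize=4)
#   def get_model(model_id: str):
#       return init_model(model_id)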

app = FastAPI()


@app.post("/predict")
async def predict(
    image_url: str,
    texts: str = "hat",
    model_id: str = "yoloe-11l",
    conf: float = 0.25,
    iou: float = 0.7,
):
    # Initialize the requested model (per request; see the caching sketch above)
    model = init_model(model_id)
    # Parse the comma-separated prompt list, e.g. "hat,shirt"
    class_list = [text.strip() for text in texts.split(",")]
    # Download and open the image from the URL
    response = requests.get(image_url)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
    # Embed the text prompts and set them as the model's classes
    text_embeddings = model.get_text_pe(class_list)
    model.set_classes(class_list, text_embeddings)
    # Run inference with the PIL Image
    results = model.predict(source=image, conf=conf, iou=iou)
    # Extract detection results from the first (and only) image
    result = results[0]
    detections = []
    for box in result.boxes:
        detections.append(
            {
                "class": result.names[int(box.cls[0])],
                "confidence": float(box.conf[0]),
                "bbox": box.xyxy[0].tolist(),  # convert the bbox tensor to a list
            }
        )
    # Optional (disabled): render an annotated image with supervision.
    # A separate variable name avoids shadowing the JSON list above.
    # sv_detections = sv.Detections.from_ultralytics(results[0])
    # resolution_wh = image.size
    # thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh)
    # text_scale = sv.calculate_optimal_text_scale(resolution_wh=resolution_wh)
    # labels = [
    #     f"{class_name} {confidence:.2f}"
    #     for class_name, confidence
    #     in zip(sv_detections["class_name"], sv_detections.confidence)
    # ]
    # annotated_image = image.copy()
    # annotated_image = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX, opacity=0.4).annotate(
    #     scene=annotated_image, detections=sv_detections)
    # annotated_image = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=thickness).annotate(
    #     scene=annotated_image, detections=sv_detections)
    # annotated_image = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX, text_scale=text_scale, smart_position=True).annotate(
    #     scene=annotated_image, detections=sv_detections, labels=labels)
    return {"detections": detections}
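
# Example response shape (illustrative values only):
#   {"detections": [{"class": "hat", "confidence": 0.87,
#                    "bbox": [34.2, 10.5, 120.8, 98.1]}]}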

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
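
# Example request (assumes the server is running locally on port 7860 and the
# image URL is publicly reachable; values are illustrative):
#
#   curl -X POST "http://localhost:7860/predict?image_url=https://example.com/photo.jpg&texts=hat,shirt&conf=0.3"
#
# All arguments are plain scalar parameters, so FastAPI reads them from the
# query string even though the route is declared as POST.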