sam / app.py
sezer91's picture
ds
3fb7c45
raw
history blame
2.96 kB
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
import numpy as np
from transformers import SamModel, SamProcessor
import io
import base64
import torch
import uvicorn
app = FastAPI(title="SAM-ViT-Base API")
# SAM modelini ve işlemciyi yükle
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
@app.post("/segment/")
async def segment_image(file: UploadFile = File(...)):
try:
# Görüntüyü oku
image_data = await file.read()
image = Image.open(io.BytesIO(image_data)).convert("RGB")
# Görüntü boyutlarını al
original_width, original_height = image.size
if original_width < 64 or original_height < 64:
raise HTTPException(status_code=400, detail=f"Görüntü boyutu çok küçük: {original_width}x{original_height}. Minimum 64x64 piksel olmalı.")
# Görüntüyü işlemciye hazırla
inputs = processor(image, return_tensors="pt", do_rescale=True, do_resize=True)
# Model ile segmentasyon yap
with torch.no_grad():
outputs = model(**inputs)
# Maskeleri al
masks = outputs.pred_masks.detach().cpu().numpy() # Shape: (batch_size, num_masks, height, width)
if masks.shape[1] == 0:
raise HTTPException(status_code=500, detail="Hiç maske üretilmedi.")
# En iyi maskeyi seç
iou_scores = outputs.iou_scores.detach().cpu().numpy() # Shape: (batch_size, num_masks)
if iou_scores.shape[1] > 1:
best_mask_idx = np.argmax(iou_scores[0]) # En yüksek skora sahip maskeyi seç
else:
best_mask_idx = 0 # Tek maske varsa onu kullan
mask = masks[0][best_mask_idx] # Shape: (height, width)
# Maske şeklini kontrol et
if len(mask.shape) != 2:
raise HTTPException(status_code=500, detail=f"Hatalı maske şekli: {mask.shape}. 2D matris bekleniyor.")
# Maskeyi binary hale getir
mask = (mask > 0).astype(np.uint8) * 255
# Maskeyi orijinal görüntü boyutlarına yeniden boyutlandır
mask_image = Image.fromarray(mask).resize((original_width, original_height), Image.NEAREST)
# Maskeyi PNG olarak kaydet
buffered = io.BytesIO()
mask_image.save(buffered, format="PNG")
mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
return JSONResponse(content={"mask": f"data:image/png;base64,{mask_base64}"})
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
async def root():
return {"message": "SAM-ViT-Base API çalışıyor. /segment endpoint'ine görüntü yükleyin."}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)