|
import base64
import io
import logging

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import SamModel, SamProcessor
|
|
|
# Checkpoint identifier shared by the model and its processor. Keeping it in
# one place guarantees both are always loaded from the same checkpoint.
MODEL_ID = "facebook/sam-vit-base"

app = FastAPI(title="SAM-ViT-Base API")

# Both objects are created once at import time and reused by every request.
model = SamModel.from_pretrained(MODEL_ID)
processor = SamProcessor.from_pretrained(MODEL_ID)
|
|
|
def _select_best_mask(masks, iou_scores):
    """Pick the 2-D candidate mask with the highest predicted IoU score.

    Args:
        masks: NumPy array of predicted masks whose last two axes are the
            mask height and width; any leading axes are treated as a flat
            list of candidates.
        iou_scores: NumPy array holding one score per candidate mask.

    Returns:
        The selected candidate as a 2-D NumPy array.

    Raises:
        HTTPException: 500 if the array has fewer than two axes or contains
            no candidate masks.
    """
    if masks.ndim < 2:
        raise HTTPException(status_code=500, detail=f"Hatalı maske şekli: {masks.shape}. 2D matris bekleniyor.")
    if masks.size == 0:
        raise HTTPException(status_code=500, detail="Hiç maske üretilmedi.")

    # NOTE(review): the model output appears to carry extra leading axes
    # (batch / prompt batch / candidates) — TODO confirm against the installed
    # transformers version. Flattening them makes the selection independent of
    # how many leading axes there are; only one image is ever submitted here.
    candidates = masks.reshape(-1, *masks.shape[-2:])
    scores = iou_scores.reshape(-1)

    # Only trust the scores when there is exactly one per candidate;
    # otherwise fall back to the first candidate, as the original code did.
    if scores.size == candidates.shape[0]:
        best_idx = int(np.argmax(scores))
    else:
        best_idx = 0
    return candidates[best_idx]


@app.post("/segment/")
async def segment_image(file: UploadFile = File(...)):
    """Segment an uploaded image and return the best mask as a PNG data URI.

    Args:
        file: Uploaded image file; it is decoded with Pillow and converted
            to RGB.

    Returns:
        JSONResponse whose ``mask`` field is a ``data:image/png;base64,...``
        string containing a binary (0/255) mask resized to the dimensions of
        the uploaded image.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image or is
            smaller than 64x64 pixels; 500 if no usable mask is produced or
            any other unexpected error occurs.
    """
    try:
        image_data = await file.read()

        # An undecodable upload is a client error, not a server fault.
        try:
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
        except (OSError, ValueError) as exc:
            raise HTTPException(status_code=400, detail=f"Geçersiz görüntü dosyası: {exc}") from exc

        original_width, original_height = image.size
        if original_width < 64 or original_height < 64:
            raise HTTPException(status_code=400, detail=f"Görüntü boyutu çok küçük: {original_width}x{original_height}. Minimum 64x64 piksel olmalı.")

        inputs = processor(image, return_tensors="pt", do_rescale=True, do_resize=True)

        # Inference only: skip gradient bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)

        masks = outputs.pred_masks.detach().cpu().numpy()
        iou_scores = outputs.iou_scores.detach().cpu().numpy()
        mask = _select_best_mask(masks, iou_scores)

        # Threshold the raw prediction at zero into a 0/255 8-bit image.
        mask = (mask > 0).astype(np.uint8) * 255

        # NOTE(review): the predicted mask is stretched straight to the
        # original size. If the processor pads the input before inference,
        # non-square images could come out misaligned — consider the
        # processor's post_process_masks helper; TODO confirm.
        mask_image = Image.fromarray(mask).resize((original_width, original_height), Image.NEAREST)

        buffered = io.BytesIO()
        mask_image.save(buffered, format="PNG")
        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return JSONResponse(content={"mask": f"data:image/png;base64,{mask_base64}"})

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400s above) must keep their own
        # status code instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        # Record the traceback server-side before reporting a generic 500.
        logging.getLogger(__name__).exception("Unexpected error while segmenting image")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
@app.get("/")
async def root():
    """Return a short status message pointing callers at the upload endpoint."""
    status_payload = {
        "message": "SAM-ViT-Base API çalışıyor. /segment endpoint'ine görüntü yükleyin.",
    }
    return status_payload
|
|
|
if __name__ == "__main__":
    # Start the server only when this file is executed as a script;
    # it listens on every network interface at port 7860.
    uvicorn.run(
        app,
        port=7860,
        host="0.0.0.0",
    )