Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- app.py +46 -13
- requirements.txt +3 -1
app.py
CHANGED
@@ -1,20 +1,53 @@
|
|
1 |
-
from fastapi import FastAPI
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from pydantic import BaseModel
|
3 |
|
4 |
app = FastAPI(title="SAM_MedTesting")
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
return {"generated_text": out}
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, File, UploadFile
|
2 |
+
from transformers import SamModel, SamProcessor
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
import io
|
6 |
+
import numpy as np
|
7 |
+
from fastapi.responses import StreamingResponse
|
8 |
from pydantic import BaseModel
|
9 |
|
10 |
app = FastAPI(title="SAM_MedTesting")
|
11 |
+
MODEL_ID = "facebook/sam-vit-base"
|
12 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
+
processor = SamProcessor.from_pretrained(MODEL_ID)
|
14 |
+
model = SamModel.from_pretrained(MODEL_ID).to(device)
|
15 |
|
16 |
+
@app.post("/segment")
|
17 |
+
async def segment(
|
18 |
+
file: UploadFile = File(...),
|
19 |
+
# an optional list of prompt points [[x1,y1], [x2,y2], …]
|
20 |
+
input_points: list[list[int]] | None = None
|
21 |
+
):
|
22 |
+
|
23 |
+
# read & decode image
|
24 |
+
data = await file.read()
|
25 |
+
img = Image.open(io.BytesIO(data)).convert("RGB")
|
26 |
|
27 |
+
# prepare inputs for SAM
|
28 |
+
inputs = processor(
|
29 |
+
img,
|
30 |
+
input_points=[input_points] if input_points else None,
|
31 |
+
return_tensors="pt"
|
32 |
+
).to(device)
|
33 |
|
34 |
+
# forward pass
|
35 |
+
with torch.no_grad():
|
36 |
+
outputs = model(**inputs)
|
|
|
37 |
|
38 |
+
# post-process to binary masks
|
39 |
+
masks = processor.image_processor.post_process_masks(
|
40 |
+
outputs.pred_masks.cpu(),
|
41 |
+
inputs["original_sizes"].cpu(),
|
42 |
+
inputs["reshaped_input_sizes"].cpu()
|
43 |
+
)
|
44 |
+
# take the first mask
|
45 |
+
mask_np = (masks[0][0] * 255).astype(np.uint8)
|
46 |
+
|
47 |
+
# turn it into a PNG
|
48 |
+
pil_mask = Image.fromarray(mask_np)
|
49 |
+
buf = io.BytesIO()
|
50 |
+
pil_mask.save(buf, format="PNG")
|
51 |
+
buf.seek(0)
|
52 |
+
|
53 |
+
return StreamingResponse(buf, media_type="image/png")
|
requirements.txt
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
fastapi
|
2 |
uvicorn[standard]
|
3 |
transformers
|
4 |
-
torch
|
|
|
|
|
|
1 |
fastapi
|
2 |
uvicorn[standard]
|
3 |
transformers
|
4 |
+
torch
|
5 |
+
cv2
|
6 |
+
Pillow
|