Axzyl commited on
Commit
2474580
·
verified ·
1 Parent(s): c3728b8

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +46 -13
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,20 +1,53 @@
1
- from fastapi import FastAPI
 
 
 
 
 
 
2
  from pydantic import BaseModel
3
 
4
  app = FastAPI(title="SAM_MedTesting")
 
 
 
 
5
 
6
- class GenerationRequest(BaseModel):
7
- prompt: str
8
- max_new_tokens: int = 50
 
 
 
 
 
 
 
9
 
10
- class GenerationResponse(BaseModel):
11
- generated_text: str
 
 
 
 
12
 
13
- @app.post("/generate", response_model=GenerationResponse)
14
- def generate(req: GenerationRequest):
15
- out = f"hello world: {req}"
16
- return {"generated_text": out}
17
 
18
- @app.get("/health")
19
- def health():
20
- return {"status": "ok"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# --- Standard library ---
import io

# --- Third-party ---
import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image
from pydantic import BaseModel
from transformers import SamModel, SamProcessor

app = FastAPI(title="SAM_MedTesting")

# Load the Segment Anything model exactly once at import time so every
# request reuses the same weights; prefer GPU when one is available.
MODEL_ID = "facebook/sam-vit-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = SamProcessor.from_pretrained(MODEL_ID)
model = SamModel.from_pretrained(MODEL_ID).to(device)
15
 
16
@app.post("/segment")
async def segment(
    file: UploadFile = File(...),
    # Optional prompt points [[x1, y1], [x2, y2], ...] guiding SAM.
    # NOTE(review): FastAPI cannot parse a JSON-typed parameter alongside a
    # multipart File(...) upload like this — clients most likely cannot send
    # it. Consider accepting it as a Form() string and json.loads-ing it.
    input_points: list[list[int]] | None = None
):
    """Segment the uploaded image with SAM and return the first mask as a PNG.

    Parameters:
        file: The image to segment (any format PIL can decode).
        input_points: Optional pixel-coordinate prompt points.

    Returns:
        StreamingResponse: a single-channel PNG, mask pixels 255, background 0.
    """
    # Read & decode the uploaded image.
    data = await file.read()
    img = Image.open(io.BytesIO(data)).convert("RGB")

    # Prepare inputs for SAM; the processor expects points nested per image.
    inputs = processor(
        img,
        input_points=[input_points] if input_points else None,
        return_tensors="pt"
    ).to(device)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process low-res predictions into full-resolution boolean masks.
    # Each list entry is a tensor of shape (point_batch, num_masks, H, W).
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )

    # BUG FIX: masks[0][0] is a torch.Tensor (no .astype()) and is 3-D
    # (num_masks, H, W), which Image.fromarray would reject. Select the
    # first 2-D mask, convert to numpy, then scale 0/1 -> 0/255.
    mask_np = masks[0][0, 0].numpy().astype(np.uint8) * 255

    # Encode the mask as a PNG held entirely in memory.
    pil_mask = Image.fromarray(mask_np)
    buf = io.BytesIO()
    pil_mask.save(buf, format="PNG")
    buf.seek(0)

    return StreamingResponse(buf, media_type="image/png")
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
  fastapi
2
  uvicorn[standard]
3
  transformers
4
- torch
 
 
 
1
  fastapi
2
  uvicorn[standard]
3
  transformers
4
+ torch
5
+ opencv-python
6
+ Pillow