shekzee commited on
Commit
64a04e7
·
verified ·
1 Parent(s): 21495d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -23
app.py CHANGED
@@ -1,31 +1,27 @@
1
  from fastapi import FastAPI, UploadFile, File
2
- from fastapi.responses import JSONResponse
3
  from PIL import Image
4
- import torch, torchvision.transforms as T
5
- from transformers import MobileNetV2ForSemanticSegmentation
6
- import io
7
-
8
- # Load the model
9
- model = MobileNetV2ForSemanticSegmentation.from_pretrained("seg_model")
10
- model.eval()
11
-
12
- preprocess = T.Compose([
13
- T.Resize(513),
14
- T.ToTensor(),
15
- T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
16
- ])
17
 
18
  app = FastAPI()
19
 
20
- @app.get("/")
21
- def root():
22
- return {"status": "API up for segmentation"}
 
 
 
 
23
 
24
  @app.post("/predict")
25
  async def predict(file: UploadFile = File(...)):
26
- img = Image.open(await file.read()).convert("RGB")
27
- x = preprocess(img).unsqueeze(0)
28
- with torch.no_grad():
29
- outputs = model(x).logits
30
- seg = outputs.argmax(1)[0].tolist()
31
- return JSONResponse(content={"segmentation_mask": seg})
 
 
 
1
  from fastapi import FastAPI, UploadFile, File
2
+ from fastapi.middleware.cors import CORSMiddleware
3
  from PIL import Image
4
+ from io import BytesIO
5
+ import numpy as np
6
+ import tensorflow as tf
 
 
 
 
 
 
 
 
 
 
7
 
8
  app = FastAPI()
9
 
10
+ # CORS config (optional)
11
+ app.add_middleware(
12
+ CORSMiddleware,
13
+ allow_origins=["*"],
14
+ allow_methods=["*"],
15
+ allow_headers=["*"],
16
+ )
17
 
18
  @app.post("/predict")
19
  async def predict(file: UploadFile = File(...)):
20
+ contents = await file.read()
21
+ img = Image.open(BytesIO(contents)).convert("RGB")
22
+ img = img.resize((256, 256)) # or whatever your model expects
23
+ arr = np.array(img) / 255.0
24
+ arr = np.expand_dims(arr, 0)
25
+ # prediction = model.predict(arr)
26
+ # result = do_something_with_prediction(prediction)
27
+ return {"success": True}