MyDockerAPI / app.py
MasumBhuiyan's picture
Update app.py
05871c1 verified
raw
history blame
1.46 kB
import asyncio
import logging

import numpy as np
import torch
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from PIL import Image
from PIL import UnidentifiedImageError
from pydantic import BaseModel
app = FastAPI()

# Load the serialized model that ships inside the container image.
# SECURITY: weights_only=False lets torch.load unpickle arbitrary objects, which
# can execute code. That is acceptable only because model.pth is a trusted
# artifact baked into the image -- never point this at user-supplied files.
model_uri = "model.pth"
# map_location="cpu": preprocess_image creates its input tensor on the CPU, so
# the model's parameters must live there as well. This also lets a checkpoint
# that was saved from a GPU load on a CPU-only host.
model = torch.load(model_uri, map_location="cpu", weights_only=False)
# Define input schema for JSON requests
class ImageInput(BaseModel):
image_path: str
# Preprocess the image
def preprocess_image(image, size=(28, 28), mean=0.1307, std=0.3081):
    """Convert a PIL image into the normalized tensor fed to ``model``.

    The image is converted to 8-bit grayscale, resized, scaled to [0, 1] and
    then standardized with ``(x - mean) / std``.

    Args:
        image: A ``PIL.Image.Image`` in any mode that ``Image.convert`` accepts.
        size: Target ``(width, height)`` passed to ``Image.resize``.
            Defaults to 28x28.
        mean: Value subtracted after scaling to [0, 1]. The default (0.1307)
            is presumably the MNIST training-set mean -- confirm it matches
            how ``model.pth`` was trained.
        std: Divisor applied after subtracting ``mean``. The default (0.3081)
            is presumably the MNIST training-set std -- confirm.

    Returns:
        A ``float32`` tensor of shape ``(1, height, width)``.
    """
    grayscale = image.convert('L')  # single 8-bit channel
    resized = grayscale.resize(size)
    pixels = np.array(resized) / 255.0  # uint8 -> float64 in [0, 1]
    pixels = (pixels - mean) / std  # standardize
    # unsqueeze(0) prepends one axis. NOTE(review): the model may treat that axis
    # as the batch dimension (e.g. a Flatten-first MLP) or as the channel
    # dimension of an unbatched conv input; which one depends on model.pth.
    return torch.tensor(pixels).unsqueeze(0).float()
# Root endpoint.
@app.get("/")
def greet_json():
    """Handle ``GET /`` by returning a constant JSON greeting."""
    greeting = {"Hello": "World!"}
    return greeting
def _classify_upload(fileobj):
    """Decode an uploaded image, run the model on it and return the argmax class.

    This is blocking, CPU-bound work (image decoding plus a forward pass), so
    the async endpoint below runs it in a worker thread rather than on the
    event loop.
    """
    # Closing the PIL image releases its decoder state. Pillow only closes file
    # objects it opened itself, so the upload's file stays owned by FastAPI.
    with Image.open(fileobj) as pil_image:
        tensor = preprocess_image(pil_image)
    model.eval()
    # Grad mode is thread-local, so no_grad must be entered here, in the worker.
    with torch.no_grad():
        output = model(tensor)
    # NOTE(review): assumes output is (batch, num_classes) -- confirm for model.pth.
    return output.argmax(dim=1).item()


# Predict endpoint: expects a multipart/form-data upload in the ``file`` field
# (not a JSON body).
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and report the predicted digit.

    Returns:
        200 with ``{"prediction": "The digit is <n>"}`` on success,
        400 with ``{"error": ...}`` when Pillow cannot identify the upload as
        an image, and 500 with ``{"error": ...}`` for any other failure.
    """
    try:
        # Run the blocking decode + inference in the default executor so the
        # event loop keeps serving other requests in the meantime.
        loop = asyncio.get_running_loop()
        prediction = await loop.run_in_executor(None, _classify_upload, file.file)
    except UnidentifiedImageError as e:
        # The client sent data Pillow cannot read: a client error, not a server fault.
        return JSONResponse(content={"error": str(e)}, status_code=400)
    except Exception as e:
        # Top-level boundary: record the traceback instead of silently swallowing it.
        logging.getLogger(__name__).exception("Prediction failed")
        return JSONResponse(content={"error": str(e)}, status_code=500)
    return JSONResponse(content={"prediction": f"The digit is {prediction}"})