# MyDockerAPI / app.py
# Last commit: 1a9e80b "Upload model files" (BhuiyanMasum)
# Residue from the source web view: raw | history | blame | 1.41 kB
import numpy as np
from PIL import Image
import torch
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from pydantic import BaseModel
app = FastAPI()

# Load the pre-trained model once, at import time.
# map_location="cpu" lets a checkpoint that was saved from a GPU process load on
# a CPU-only host, and keeps the weights on the same device as the CPU tensors
# produced by preprocess_image.
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted files.
# The handlers call model.eval() / model(x), so model.pth presumably holds a whole
# pickled nn.Module; on torch >= 2.6 torch.load defaults to weights_only=True,
# which rejects such files - confirm the pinned torch version.
model_uri = "model.pth"
model = torch.load(model_uri, map_location="cpu")
# Switch layers such as dropout / batch-norm to inference behaviour up front.
model.eval()
# Request body schema for JSON clients
class ImageInput(BaseModel):
    """JSON body carrying a single ``image_path`` string.

    NOTE(review): no endpoint visible in this module consumes this model;
    ``/predict`` accepts a multipart file upload instead.
    """

    image_path: str
# Preprocess an image into the normalized tensor the model consumes
def preprocess_image(image, *, size=(28, 28), mean=0.1307, std=0.3081):
    """Convert a PIL-style image into a standardized float32 tensor.

    The image is converted to single-channel grayscale (mode ``"L"``),
    resized, scaled from ``[0, 255]`` to ``[0, 1]`` and then standardized as
    ``(x - mean) / std``. The defaults reproduce the constants this module was
    written against. 0.1307 / 0.3081 are the commonly quoted MNIST statistics,
    presumably the ones the model was trained with - confirm.

    Args:
        image: Object exposing PIL's ``convert`` and ``resize`` methods
            (normally a ``PIL.Image.Image``).
        size: Target ``(width, height)`` passed to ``image.resize``.
        mean: Value subtracted after scaling to ``[0, 1]``.
        std: Divisor applied after subtracting ``mean``; must be non-zero.

    Returns:
        A ``torch.float32`` tensor of shape ``(1, height, width)``: the 2-D
        grayscale array with one leading axis prepended.
        NOTE(review): a Conv2d-based model usually expects ``(N, C, H, W)``;
        confirm the saved model accepts this 3-D layout.

    Raises:
        ValueError: If ``std`` is zero.
    """
    if std == 0:
        raise ValueError("std must be non-zero")
    gray = image.convert("L").resize(size)
    # Work in float64 and only narrow to float32 at the end, so the default
    # parameters yield exactly the same values as the original implementation.
    pixels = np.asarray(gray, dtype=np.float64) / 255.0
    standardized = (pixels - mean) / std
    return torch.from_numpy(standardized).unsqueeze(0).float()
# Root endpoint: static greeting
@app.get("/")
def greet_json():
    """Return the constant JSON payload ``{"Hello": "World!"}``."""
    payload = dict(Hello="World!")
    return payload
def _run_inference(stream):
    """Decode ``stream`` as an image and return the model's predicted class index.

    This is CPU-bound work (image decoding plus a forward pass). It is meant
    to run in a worker thread rather than on the event loop.
    """
    # Close the PIL image once the preprocessed tensor has been built.
    with Image.open(stream) as img:
        tensor = preprocess_image(img)
    # Idempotent; ensures inference behaviour for dropout / batch-norm layers.
    model.eval()
    with torch.no_grad():
        output = model(tensor)
    # assumes output is (batch, num_classes) - argmax over the class axis
    return output.argmax(dim=1).item()


# Predict endpoint for multipart image uploads (form field "file")
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and report the predicted digit.

    Args:
        file: Multipart upload containing the image bytes.

    Returns:
        A ``JSONResponse`` containing one of:

        - ``{"prediction": "The digit is <n>"}`` on success.
        - ``{"error": <message>}`` with status 400 if the upload cannot be
          identified as an image.
        - ``{"error": <message>}`` with status 500 for any other failure;
          the traceback is logged server-side.
    """
    # Local imports keep this endpoint self-contained.
    import logging

    from fastapi.concurrency import run_in_threadpool
    from PIL import UnidentifiedImageError

    try:
        # Offload blocking decode/inference so the event loop keeps serving
        # other requests while this one is being processed.
        prediction = await run_in_threadpool(_run_inference, file.file)
    except UnidentifiedImageError as e:
        # The client sent something PIL cannot decode: a client error.
        return JSONResponse(content={"error": str(e)}, status_code=400)
    except Exception as e:
        logging.getLogger(__name__).exception("Prediction failed")
        return JSONResponse(content={"error": str(e)}, status_code=500)
    return JSONResponse(content={"prediction": f"The digit is {prediction}"})