"""FastAPI service that predicts a digit from an uploaded image.

A pre-trained PyTorch model is loaded from ``model.pth`` at import time and
served through ``POST /predict``, which accepts a multipart file upload.
"""
import logging

import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel

logger = logging.getLogger(__name__)

app = FastAPI()

# Load the pre-trained model once, at import time.
# SECURITY: torch.load unpickles the file, which can execute arbitrary code;
# only ever point model_uri at a trusted checkpoint.
# NOTE(review): recent PyTorch releases changed the default of `weights_only`;
# if model.pth stores a whole pickled module, confirm it still loads on the
# deployed torch version.
model_uri = "model.pth"
# map_location="cpu": preprocess_image() always produces CPU tensors, so the
# model must live on the CPU as well. Without it, a checkpoint saved from a
# GPU either fails to load on a CPU-only host or fails at inference time with
# a device mismatch.
model = torch.load(model_uri, map_location="cpu")

# Side length, in pixels, of the square image fed to the model.
_INPUT_SIZE = 28
# Standardization constants applied after scaling pixels to [0, 1].
# NOTE(review): these match the commonly quoted MNIST mean/std - confirm they
# are the statistics the model was trained with.
_PIXEL_MEAN = 0.1307
_PIXEL_STD = 0.3081


# Input schema for JSON requests.
# NOTE(review): no route in this file uses this schema; it is kept because
# other modules may import it.
class ImageInput(BaseModel):
    image_path: str


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into the tensor passed to the model.

    The image is converted to grayscale, resized to 28x28, scaled to
    [0, 1], standardized, and returned as a float32 tensor of shape
    (1, 28, 28).

    Args:
        image: Any PIL image; its mode and size are normalized here.

    Returns:
        A float32 tensor of shape (1, 28, 28).
    """
    grayscale = image.convert("L")
    resized = grayscale.resize((_INPUT_SIZE, _INPUT_SIZE))
    pixels = np.array(resized) / 255.0  # Normalize to [0, 1]
    standardized = (pixels - _PIXEL_MEAN) / _PIXEL_STD
    # unsqueeze(0) adds a single leading dimension; no separate channel
    # dimension is added.
    # NOTE(review): confirm (1, 28, 28) is the input shape the model expects.
    return torch.tensor(standardized).unsqueeze(0).float()


# Root endpoint
@app.get("/")
def greet_json():
    """Return a static greeting payload."""
    return {"Hello": "World!"}


# Predict endpoint for an uploaded image file (multipart/form-data)
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Predict the digit shown in an uploaded image.

    Args:
        file: The uploaded image, sent as multipart form data.

    Returns:
        A JSON response with ``{"prediction": "The digit is <n>"}`` on
        success. On failure the body is ``{"error": <message>}`` with status
        400 when the upload cannot be read as an image, or 500 for any other
        error.
    """
    # Stage 1: decode and preprocess the upload. Pillow reports unreadable or
    # truncated image data as OSError (UnidentifiedImageError subclasses it),
    # which is a client-side problem rather than a server fault.
    try:
        with Image.open(file.file) as uploaded:
            batch = preprocess_image(uploaded)
    except OSError as exc:
        return JSONResponse(content={"error": str(exc)}, status_code=400)
    except Exception as exc:
        logger.exception("Failed to preprocess uploaded image")
        return JSONResponse(content={"error": str(exc)}, status_code=500)

    # Stage 2: run inference. Failures here are server-side.
    try:
        model.eval()  # Idempotent: keeps the model in inference mode
        with torch.no_grad():
            output = model(batch)
        prediction = output.argmax(dim=1).item()
    except Exception as exc:
        logger.exception("Model inference failed")
        return JSONResponse(content={"error": str(exc)}, status_code=500)

    return JSONResponse(content={"prediction": f"The digit is {prediction}"})