Spaces:
Sleeping
Sleeping
import numpy as np | |
from PIL import Image | |
import torch | |
from fastapi import FastAPI, UploadFile, File | |
from fastapi.responses import JSONResponse | |
from pydantic import BaseModel | |
app = FastAPI()

# Load the pre-trained model once at import time so every request reuses it.
model_uri = "model.pth"
# map_location="cpu": the request handlers in this file build plain CPU tensors,
# so the model has to live on the CPU as well; this also lets a checkpoint that
# was saved from a GPU load on a CPU-only host.
# NOTE(review): torch.load unpickles the file, so model.pth must be a trusted
# artifact. On recent PyTorch releases the default weights_only=True rejects a
# fully pickled nn.Module; if loading fails there, pass weights_only=False
# explicitly (only for a trusted file) or ship a state_dict instead.
model = torch.load(model_uri, map_location="cpu")
# Define input schema for JSON requests
class ImageInput(BaseModel):
    """Request body that carries an image location as a string.

    NOTE(review): no handler visible in this file accepts this schema;
    confirm whether it is still needed.
    """

    # Path of the image to classify (how it is resolved is not shown here).
    image_path: str
# Preprocess the image
def preprocess_image(image, *, size=(28, 28), mean=0.1307, std=0.3081):
    """Convert an image into a standardized float32 tensor.

    Args:
        image: Object exposing ``convert(mode)`` and ``resize(size)`` the way
            a ``PIL.Image.Image`` does.
        size: Target ``(width, height)`` passed to ``resize``.
        mean: Value subtracted from the [0, 1]-scaled pixels.
        std: Value the centered pixels are divided by; must be non-zero.

    Returns:
        A ``torch.float32`` tensor with one extra leading dimension added in
        front of the pixel grid.

    Raises:
        ValueError: If ``std`` is zero.

    NOTE(review): the default mean/std look like the usual MNIST statistics;
    confirm they match what the model was trained with.
    """
    if std == 0:
        raise ValueError("std must be non-zero")

    # Single-channel grayscale ("L") at the requested resolution.
    resized = image.convert("L").resize(size)

    # Scale 8-bit pixel values to [0, 1], then standardize. The math stays in
    # float64 and is cast once at the end, exactly as before.
    pixels = np.asarray(resized, dtype=np.float64) / 255.0
    pixels = (pixels - mean) / std

    # Add the leading batch dimension the caller feeds to the model.
    return torch.tensor(pixels).unsqueeze(0).float()
# Root endpoint
@app.get("/")
def greet_json():
    """Return a static greeting as a JSON object."""
    return {"Hello": "World!"}
# Predict endpoint for an uploaded image file (multipart form upload)
# NOTE(review): the route path is not visible in this file and is assumed
# here; confirm it against existing clients.
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and report the predicted class.

    Args:
        file: The uploaded image file.

    Returns:
        A JSON response with ``{"prediction": "The digit is <n>"}`` on
        success, or ``{"error": <message>}`` with status 400 when the upload
        cannot be read as an image and status 500 for any other failure.
    """
    try:
        # Read and preprocess the uploaded image
        image = Image.open(file.file)
        batch = preprocess_image(image)

        # Make prediction; eval mode and no_grad keep inference free of
        # training-time behavior and gradient bookkeeping.
        model.eval()
        with torch.no_grad():
            output = model(batch)

        # Index of the highest score along dim 1 (presumably the class axis).
        prediction = output.argmax(dim=1).item()
        return JSONResponse(content={"prediction": f"The digit is {prediction}"})
    except OSError as e:
        # Pillow reports undecodable or truncated uploads as OSError
        # (UnidentifiedImageError is a subclass): a client error, not a
        # server fault.
        return JSONResponse(content={"error": str(e)}, status_code=400)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)