BhuiyanMasum commited on
Commit
1a9e80b
·
1 Parent(s): 1e7060d

Upload model files

Browse files
Files changed (3) hide show
  1. app.py +45 -1
  2. model.pth +3 -0
  3. requirements.txt +4 -1
app.py CHANGED
@@ -1,8 +1,52 @@
1
- from fastapi import FastAPI
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
 
 
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  @app.get("/")
7
  def greet_json():
8
  return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import torch
4
+ from fastapi import FastAPI, UploadFile, File
5
+ from fastapi.responses import JSONResponse
6
+ from pydantic import BaseModel
7
 
8
  app = FastAPI()
9
 
10
+ # Load the pre-trained model
11
+ model_uri = "model.pth"
12
+ model = torch.load(model_uri)
13
 
14
+
15
+ # Define input schema for JSON requests
16
+ class ImageInput(BaseModel):
17
+ image_path: str
18
+
19
+
20
+ # Preprocess the image
21
+ def preprocess_image(image):
22
+ image = image.convert('L') # Convert to grayscale
23
+ image = image.resize((28, 28))
24
+ image = np.array(image) / 255.0 # Normalize to [0, 1]
25
+ image = (image - 0.1307) / 0.3081 # Standardize
26
+ image = torch.tensor(image).unsqueeze(0).float() # Convert to tensor with batch dimension
27
+ return image
28
+
29
+
30
+ # Root endpoint
31
  @app.get("/")
32
  def greet_json():
33
  return {"Hello": "World!"}
34
+
35
+
36
+ # Predict endpoint for JSON input
37
+ @app.post("/predict")
38
+ async def predict_image(file: UploadFile = File(...)):
39
+ try:
40
+ # Read and preprocess the uploaded image
41
+ image = Image.open(file.file)
42
+ image = preprocess_image(image)
43
+
44
+ # Make prediction
45
+ model.eval()
46
+ with torch.no_grad():
47
+ output = model(image)
48
+ prediction = output.argmax(dim=1).item()
49
+
50
+ return JSONResponse(content={"prediction": f"The digit is {prediction}"})
51
+ except Exception as e:
52
+ return JSONResponse(content={"error": str(e)}, status_code=500)
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a25d6fbe70a15a02f24ff1e586b44b4fb0a626193293b44fb718a18851b9f12
3
+ size 445812
requirements.txt CHANGED
@@ -1,2 +1,5 @@
1
  fastapi
2
- uvicorn[standard]
 
 
 
 
1
  fastapi
2
+ uvicorn[standard]
3
+ numpy
4
+ Pillow
5
+ torch