Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torchvision.transforms as transforms | |
from flask import Flask, request, jsonify, render_template | |
from PIL import Image | |
import io | |
from flask_cors import CORS | |
import torch.nn.functional as F | |
class ResBlock(nn.Module):
    """Basic residual block: two 3x3 convolutions plus a shortcut connection.

    When ``input_features`` differs from ``output_features`` the block
    downsamples with stride 2 and projects the input through a 1x1 convolution
    so it can be summed with the main path. Otherwise it uses stride 1 and an
    identity shortcut.

    Args:
        input_features: number of channels of the incoming tensor.
        output_features: number of channels produced by the block.
    """

    def __init__(self, input_features, output_features):
        super().__init__()
        changes_width = input_features != output_features
        # The stride is tied to the width change. It is kept on the instance
        # because both the main path and the projection shortcut use it.
        self.stride = 2 if changes_width else 1

        # Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The final ReLU is
        # applied after the residual sum in forward(). The layer order fixes
        # the state_dict keys ("features.0", "features.1", ...), so loading a
        # saved checkpoint depends on it.
        self.features = nn.Sequential(
            nn.Conv2d(input_features, output_features, kernel_size=3,
                      stride=self.stride, padding=1, bias=False),
            nn.BatchNorm2d(output_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_features, output_features, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(output_features),
        )

        # Shortcut path: a 1x1 projection when the width changes, otherwise a
        # no-op. Both cases are wrapped in nn.Sequential so the projection
        # weight is stored under "shortcut.0.weight".
        if changes_width:
            projection = nn.Conv2d(input_features, output_features,
                                   kernel_size=1, stride=self.stride, bias=False)
        else:
            projection = nn.Identity()
        self.shortcut = nn.Sequential(projection)

    def forward(self, x):
        """Return ``relu(features(x) + shortcut(x))``."""
        identity = self.shortcut(x)
        out = self.features(x)
        out += identity
        return F.relu(out, inplace=True)
class Resnet18(nn.Module):
    """ResNet-18-style classifier for single-channel images.

    Layout: a 7x7/stride-2 conv stem with max-pooling, eight ``ResBlock``s
    (64 -> 128 -> 256 -> 512 channels), global average pooling, and one
    linear layer.

    Args:
        num_of_classes: length of the output logit vector (default 10).
    """

    def __init__(self, num_of_classes=10):
        super().__init__()
        # Stem: the first conv takes exactly one input channel.
        stem = [
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        # (in_channels, out_channels) for each residual block, in order.
        # Changing this table changes the state_dict layout that saved
        # checkpoints expect.
        block_channels = [
            (64, 64), (64, 64),
            (64, 128), (128, 128),
            (128, 256), (256, 256),
            (256, 512), (512, 512),
        ]
        residual_stages = [ResBlock(c_in, c_out) for c_in, c_out in block_channels]
        self.features = nn.Sequential(
            *stem,
            *residual_stages,
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.classifier = nn.Sequential(nn.Linear(512, num_of_classes))

    def forward(self, x):
        """Map a batch shaped (N, 1, H, W) to logits shaped (N, num_of_classes)."""
        pooled = self.features(x)
        return self.classifier(pooled.flatten(start_dim=1))
# Load model
# Inference runs on the CPU. map_location remaps every stored tensor onto
# `device`, so the checkpoint still loads on a CPU-only host even if it was
# saved from GPU tensors.
# NOTE(review): torch.load unpickles the file. Only load trusted checkpoints
# (consider weights_only=True if the installed torch supports it).
device = "cpu"
model = Resnet18().to(device)
model.load_state_dict(torch.load("resnet_mnist_cpu.pth", map_location=device))
# Switch BatchNorm layers to their running statistics for inference.
model.eval()
# Define image preprocessing.
# Applied to every uploaded image before inference:
#   - collapse to a single channel (Resnet18's first conv expects 1 input channel),
#   - resize to 224x224,
#   - convert to a float tensor.
# NOTE(review): there is no Normalize step. Confirm this matches the
# preprocessing used when the checkpoint was trained.
transform = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
# Initialize Flask app.
# CORS(app) adds cross-origin headers so browser pages served from other
# origins can call this app's routes.
app = Flask(import_name=__name__)
CORS(app)
@app.route("/")
def home():
    """Serve the upload page rendered from the ``index.html`` template."""
    # Without the route decorator this view was never registered with `app`.
    return render_template("index.html")
# Route to handle image predictions.
# NOTE(review): the "/predict" path and POST method are assumed, because the
# original listing had no route decorator (so the view was unreachable).
# Confirm they match the URL that templates/index.html submits the upload to.
@app.route("/predict", methods=["POST"])
def predict():
    """Classify one uploaded image with the loaded model.

    Expects a multipart/form-data request whose file field is named "image".

    Returns:
        JSON ``{"prediction": label}`` on success. When the "image" field is
        missing, or its bytes cannot be decoded as an image, responds with
        HTTP 400 and JSON ``{"error": ...}``.
    """
    upload = request.files.get("image")
    if upload is None:
        return jsonify({"error": "missing 'image' file field"}), 400

    try:
        image = Image.open(io.BytesIO(upload.read()))
        # Image.open is lazy. Force decoding here so corrupt or non-image
        # data becomes a 400 instead of a 500 raised later inside `transform`.
        image.load()
    except OSError:  # also covers PIL.UnidentifiedImageError
        return jsonify({"error": "could not decode uploaded image"}), 400

    # Add a batch dimension: (1, 224, 224) -> (1, 1, 224, 224).
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    predicted = logits.argmax(dim=1)

    class_labels = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9" ]  # Modify based on your dataset
    prediction = class_labels[predicted.item()]
    return jsonify({"prediction": prediction})
if __name__ == "__main__":
    # Bind to every network interface (not just localhost) on port 7860.
    app.run("0.0.0.0", 7860)