import io

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from PIL import Image


class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a shortcut.

    When ``input_features != output_features`` the first convolution and the
    shortcut both use stride 2, and the shortcut is a bias-free 1x1
    convolution; otherwise the stride is 1 and the shortcut is the identity.

    NOTE: the position of every layer inside the ``nn.Sequential`` containers
    determines the parameter names in the ``state_dict``. Do not reorder or
    insert layers without retraining / re-exporting the checkpoint.
    """

    def __init__(self, input_features: int, output_features: int) -> None:
        super().__init__()
        self.stride = 1 if input_features == output_features else 2
        self.features = nn.Sequential(
            nn.Conv2d(
                input_features,
                output_features,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(output_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                output_features,
                output_features,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(output_features),
        )
        if input_features != output_features:
            # Project the input so it matches the main path's channel count
            # and (strided) spatial size before the addition in forward().
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    input_features,
                    output_features,
                    kernel_size=1,
                    stride=self.stride,
                    bias=False,
                )
            )
        else:
            self.shortcut = nn.Sequential(nn.Identity())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(features(x) + shortcut(x))``."""
        residual = self.shortcut(x)
        out = self.features(x)
        out += residual
        return F.relu(out, inplace=True)


class Resnet18(nn.Module):
    """ResNet-18-style classifier taking single-channel images.

    The stem is a 7x7 stride-2 convolution followed by batch-norm, ReLU and a
    3x3 stride-2 max-pool. Eight ``ResBlock`` stages (64, 64, 128, 128, 256,
    256, 512, 512 channels) follow, then global average pooling and one
    linear layer producing ``num_of_classes`` logits.
    """

    def __init__(self, num_of_classes: int = 10) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            ResBlock(64, 64),
            ResBlock(64, 64),
            ResBlock(64, 128),
            ResBlock(128, 128),
            ResBlock(128, 256),
            ResBlock(256, 256),
            ResBlock(256, 512),
            ResBlock(512, 512),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.classifier = nn.Sequential(
            nn.Linear(512, num_of_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits for a batch of images."""
        x = self.features(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        return self.classifier(x)


# Load model
device = "cpu"
model = Resnet18().to(device)
# map_location remaps every stored tensor onto `device`, so the checkpoint
# loads on a CPU-only host even if it was saved from a GPU.
model.load_state_dict(torch.load("resnet_mnist_cpu.pth", map_location=device))
model.eval()  # inference mode: batch-norm uses its running statistics

# Define image preprocessing
transform = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Initialize Flask app
app = Flask(__name__)
CORS(app)

# Label for each output index of the model.
CLASS_LABELS = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]  # Modify based on your dataset


@app.route("/")
def home():
    """Render the ``index.html`` template."""
    return render_template("index.html")


# Route to handle image predictions
@app.route("/predict", methods=["POST"])
def predict():
    """Classify the image uploaded in the ``image`` form field.

    Returns:
        ``{"prediction": <label>}`` on success, or ``{"error": <message>}``
        with HTTP status 400 when the field is missing or the payload cannot
        be decoded as an image.
    """
    upload = request.files.get("image")
    if upload is None:
        return jsonify({"error": "Missing file field 'image'."}), 400

    try:
        image = Image.open(io.BytesIO(upload.read()))
        # Image.open only parses the header; force full decoding here so that
        # truncated or corrupt data is reported as a client error instead of
        # failing later inside the transform.
        image.load()
    except OSError:
        # Covers PIL.UnidentifiedImageError and truncated-file errors,
        # which are both OSError subclasses.
        return jsonify({"error": "Uploaded file is not a valid image."}), 400

    batch = transform(image).unsqueeze(0).to(device)  # add the batch dimension
    with torch.no_grad():
        outputs = model(batch)
    predicted = outputs.argmax(dim=1).item()
    return jsonify({"prediction": CLASS_LABELS[predicted]})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)