sivakum4 committed
Commit bbea32b
Parent(s): 90d6c8e

code and model

Files changed (9):
  1. .DS_Store +0 -0
  2. .gitattributes +1 -0
  3. Dockerfile +17 -0
  4. README.md +0 -11
  5. requirements.txt +4 -0
  6. resnet_mnist_cpu.pth +3 -0
  7. server.py +104 -0
  8. static/script.js +75 -0
  9. templates/index.html +40 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ resnet_mnist_cpu.pth filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,17 @@
+ # Use a lightweight Python image
+ FROM python:3.9-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Copy all files into container
+ COPY . .
+
+ # Install dependencies
+ RUN pip install -r requirements.txt
+
+ # Expose port 7860 for Hugging Face Spaces
+ EXPOSE 7860
+
+ # Run Flask app
+ CMD ["python", "server.py"]
README.md CHANGED
@@ -1,11 +0,0 @@
- ---
- title: Numberclassifier
- emoji: 😻
- colorFrom: purple
- colorTo: pink
- sdk: docker
- pinned: false
- short_description: 'Code Tutorial '
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ torchvision
+ flask
+ flask-cors
resnet_mnist_cpu.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b30cdcd031021b676ff2dbae74c8e63a903533257d195c58fcd79733197c2a3c
+ size 44762627
server.py ADDED
@@ -0,0 +1,104 @@
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as transforms
+ from flask import Flask, request, jsonify, render_template
+ from PIL import Image
+ import io
+ from flask_cors import CORS
+ import torch.nn.functional as F
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, input_features, output_features):
+         super(ResBlock, self).__init__()
+         self.stride = 1 if input_features == output_features else 2
+         self.features = nn.Sequential(
+             nn.Conv2d(input_features, output_features,
+                       kernel_size=3, stride=self.stride, padding=1, bias=False),
+             nn.BatchNorm2d(output_features),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(output_features, output_features,
+                       kernel_size=3, stride=1, padding=1, bias=False),
+             nn.BatchNorm2d(output_features)
+         )
+
+         self.shortcut = nn.Sequential(nn.Identity())
+         if input_features != output_features:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(input_features, output_features, kernel_size=1, stride=self.stride, bias=False))
+
+     def forward(self, x):
+         residual = self.shortcut(x)
+         x = self.features(x)
+         x += residual
+         x = F.relu(x, inplace=True)
+         return x
+
+ class Resnet18(nn.Module):
+     def __init__(self, num_of_classes=10):
+         super(Resnet18, self).__init__()
+         self.features = nn.Sequential(
+             nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
+             nn.BatchNorm2d(64),
+             nn.ReLU(inplace=True),
+             nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+
+             ResBlock(64, 64),
+             ResBlock(64, 64),
+
+             ResBlock(64, 128),
+             ResBlock(128, 128),
+
+             ResBlock(128, 256),
+             ResBlock(256, 256),
+
+             ResBlock(256, 512),
+             ResBlock(512, 512),
+
+             nn.AdaptiveAvgPool2d((1, 1))
+         )
+         self.classifier = nn.Sequential(
+             nn.Linear(512, num_of_classes)
+         )
+
+     def forward(self, x):
+         x = self.features(x)
+         x = torch.flatten(x, 1)
+         x = self.classifier(x)
+         return x
+
+ # Load model
+ device = "cpu"
+ model = Resnet18().to(device)
+ model.load_state_dict(torch.load("resnet_mnist_cpu.pth"))
+ model.eval()
+
+ # Define image preprocessing
+ transform = transforms.Compose([transforms.Grayscale(), transforms.Resize((224, 224)), transforms.ToTensor()])
+
+ # Initialize Flask app
+ app = Flask(__name__)
+ CORS(app)
+
+ @app.route("/")
+ def home():
+     return render_template("index.html")
+
+ # Route to handle image predictions
+ @app.route("/predict", methods=["POST"])
+ def predict():
+     file = request.files["image"].read()
+     image = Image.open(io.BytesIO(file))
+     image = transform(image).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         outputs = model(image)
+         _, predicted = torch.max(outputs, 1)
+
+     class_labels = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]  # Modify based on your dataset
+     prediction = class_labels[predicted.item()]
+
+     return jsonify({"prediction": prediction})
+
+ if __name__ == "__main__":
+     app.run(host="0.0.0.0", port=7860)
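
The /predict route can also be exercised without the browser UI. Below is a rough client sketch, assuming a local run of server.py (or the container) reachable at http://localhost:7860 and that the requests library is installed (it is not listed in requirements.txt). The synthetic digit drawing is only illustrative.

# Rough client for the /predict route (illustrative; not part of this commit).
# Assumptions: server reachable at http://localhost:7860, `requests` installed.
import io

import requests
from PIL import Image, ImageDraw

# Draw a crude white stroke on a black 500x500 canvas, roughly mimicking the
# inverted drawing that static/script.js uploads.
img = Image.new("L", (500, 500), 0)
draw = ImageDraw.Draw(img)
draw.line([(250, 100), (250, 400)], fill=255, width=30)

buf = io.BytesIO()
img.save(buf, format="PNG")
buf.seek(0)

resp = requests.post(
    "http://localhost:7860/predict",
    files={"image": ("drawing.png", buf, "image/png")},
)
print(resp.json())  # expected shape: {"prediction": "<digit>"}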
static/script.js ADDED
@@ -0,0 +1,75 @@
+ let canvas;
+ let ctx;
+ document.addEventListener("DOMContentLoaded", () => {
+     canvas = document.getElementById("drawingCanvas");
+     ctx = canvas.getContext("2d");
+     let isDrawing = false;
+
+     canvas.addEventListener("mousedown", () => isDrawing = true);
+     canvas.addEventListener("mouseup", () => isDrawing = false);
+     canvas.addEventListener("mousemove", draw);
+
+     function draw(event) {
+         if (!isDrawing) return;
+         ctx.fillStyle = "black";
+         ctx.fillRect(event.offsetX, event.offsetY, 20, 20);
+     }
+ });
+
+ function clearCanvas() {
+     ctx.clearRect(0, 0, canvas.width, canvas.height);
+     document.getElementById("result").innerText = "";
+ }
+
+ // Inverting the images as the training images had a black background with white text,
+ // opposite of what we get from the canvas
+ function invertCanvasColors(inputCanvas) {
+     const tempCanvas = document.createElement("canvas");
+     const tempCtx = tempCanvas.getContext("2d");
+
+     // Set the size of the temporary canvas to match the original canvas
+     tempCanvas.width = inputCanvas.width;
+     tempCanvas.height = inputCanvas.height;
+
+     // Draw the original canvas onto the temporary canvas
+     tempCtx.drawImage(inputCanvas, 0, 0);
+
+     // Get the pixel data of the image
+     const imageData = tempCtx.getImageData(0, 0, tempCanvas.width, tempCanvas.height);
+     const data = imageData.data;
+
+     // Invert each pixel's RGB values
+     for (let i = 0; i < data.length; i += 4) {
+         data[i] = 255 - data[i];         // Red channel
+         data[i + 1] = 255 - data[i + 1]; // Green channel
+         data[i + 2] = 255 - data[i + 2]; // Blue channel
+     }
+     tempCtx.putImageData(imageData, 0, 0);
+     return tempCanvas;
+ }
+
+ function sendToServer() {
+     const invertedCanvas = invertCanvasColors(canvas);
+     let image = invertedCanvas.toDataURL("image/png");
+     let blob = dataURItoBlob(image);
+     let formData = new FormData();
+     formData.append("image", blob, "drawing.png");
+
+     fetch("https://ramachandrankulothungan-digit-doodle-recognition.hf.space/predict", {
+         method: "POST",
+         body: formData
+     })
+     .then(response => response.json())
+     .then(data => document.getElementById("result").innerText = "Prediction: " + data.prediction)
+     .catch(error => console.error("Error:", error));
+ }
+
+ function dataURItoBlob(dataURI) {
+     let byteString = atob(dataURI.split(",")[1]);
+     let arrayBuffer = new ArrayBuffer(byteString.length);
+     let uint8Array = new Uint8Array(arrayBuffer);
+     for (let i = 0; i < byteString.length; i++) {
+         uint8Array[i] = byteString.charCodeAt(i);
+     }
+     return new Blob([uint8Array], { type: "image/png" });
+ }
templates/index.html ADDED
@@ -0,0 +1,40 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <title>Digit Doodle Recognition</title>
+     <script src="https://cdn.tailwindcss.com"></script>
+     <script src="{{ url_for('static', filename='script.js') }}"></script>
+ </head>
+ <body class="bg-gray-900 text-gray-200 font-sans min-h-screen">
+     <div class="container mx-auto p-6">
+         <!-- Title -->
+         <h1 class="text-4xl font-bold text-center mb-6">Digit Doodle Recognition</h1>
+
+         <!-- Canvas Section -->
+         <div class="flex flex-col items-center gap-4">
+             <!-- Canvas -->
+             <canvas id="drawingCanvas" width="500" height="500" class="bg-gray-800 border border-gray-700 shadow-md rounded-lg"></canvas>
+
+             <!-- Buttons for Canvas -->
+             <div class="flex gap-4">
+                 <button
+                     onclick="clearCanvas()"
+                     class="bg-blue-600 hover:bg-blue-700 text-white font-medium py-2 px-4 rounded shadow-md transition duration-300">
+                     Clear
+                 </button>
+                 <button
+                     onclick="sendToServer()"
+                     class="bg-green-600 hover:bg-green-700 text-white font-medium py-2 px-4 rounded shadow-md transition duration-300">
+                     Predict
+                 </button>
+             </div>
+
+             <!-- Prediction Result -->
+             <p id="result" class="text-lg font-medium text-green-400 mt-2"></p>
+         </div>
+
+     </div>
+ </body>
+ </html>