Spaces:
Sleeping
Sleeping
code and model
Browse files- .DS_Store +0 -0
- .gitattributes +1 -0
- Dockerfile +17 -0
- README.md +0 -11
- requirements.txt +4 -0
- resnet_mnist_cpu.pth +3 -0
- server.py +104 -0
- static/script.js +75 -0
- templates/index.html +40 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
resnet_mnist_cpu.pth filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use a lightweight Python image
FROM python:3.9-slim

# Set working directory
WORKDIR /app

# Install dependencies before copying the application so this layer is
# reused from the build cache when only source files change.
# --no-cache-dir keeps pip's download cache out of the final image.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the project (server code, model weights, templates, static)
COPY . .

# Expose port 7860 for Hugging Face Spaces
EXPOSE 7860

# Run Flask app
CMD ["python", "server.py"]
|
README.md
CHANGED
@@ -1,11 +0,0 @@
|
|
1 |
-
---
|
2 |
-
title: Numberclassifier
|
3 |
-
emoji: 😻
|
4 |
-
colorFrom: purple
|
5 |
-
colorTo: pink
|
6 |
-
sdk: docker
|
7 |
-
pinned: false
|
8 |
-
short_description: 'Code Tutorial '
|
9 |
-
---
|
10 |
-
|
11 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
flask
|
4 |
+
flask-cors
|
resnet_mnist_cpu.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b30cdcd031021b676ff2dbae74c8e63a903533257d195c58fcd79733197c2a3c
|
3 |
+
size 44762627
|
server.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torchvision.transforms as transforms
|
4 |
+
from flask import Flask, request, jsonify, render_template
|
5 |
+
from PIL import Image
|
6 |
+
import io
|
7 |
+
from flask_cors import CORS
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
class ResBlock(nn.Module):
    """Residual block built from two 3x3 convolutions.

    The main path is conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm.
    Its output is added to the shortcut branch and passed through a ReLU.

    When ``input_features`` equals ``output_features`` the block uses
    stride 1 and an identity shortcut. Otherwise the first convolution uses
    stride 2 and the shortcut is a bias-free 1x1 convolution with the same
    stride, so both branches produce tensors of the same shape.
    """

    def __init__(self, input_features, output_features):
        super().__init__()
        changes_width = input_features != output_features
        self.stride = 2 if changes_width else 1

        # Layers are created in the same order as they appear in the
        # Sequential so parameter names (features.0, features.1, ...) and
        # random initialisation order stay stable.
        main_path = [
            nn.Conv2d(
                input_features,
                output_features,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(output_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                output_features,
                output_features,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(output_features),
        ]
        self.features = nn.Sequential(*main_path)

        if changes_width:
            # Project the input so it can be added to the main-path output.
            projection = nn.Conv2d(
                input_features,
                output_features,
                kernel_size=1,
                stride=self.stride,
                bias=False,
            )
            self.shortcut = nn.Sequential(projection)
        else:
            self.shortcut = nn.Sequential(nn.Identity())

    def forward(self, x):
        """Return ``relu(features(x) + shortcut(x))``."""
        skip = self.shortcut(x)
        out = self.features(x)
        out += skip
        return F.relu(out, inplace=True)
|
36 |
+
|
37 |
+
class Resnet18(nn.Module):
    """ResNet-18-style classifier for single-channel images.

    ``features`` is a stem (7x7 stride-2 convolution, BatchNorm, ReLU,
    3x3 stride-2 max-pool) followed by eight ``ResBlock`` modules with
    widths 64, 64, 128, 128, 256, 256, 512, 512 and an adaptive average
    pool to 1x1. ``classifier`` is a single linear layer from 512 features
    to ``num_of_classes`` outputs.
    """

    def __init__(self, num_of_classes=10):
        super().__init__()

        stem = [
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]

        # Consecutive pairs of this sequence give each block's
        # (input, output) width: (64, 64), (64, 64), (64, 128), ...
        widths = (64, 64, 64, 128, 128, 256, 256, 512, 512)
        blocks = [
            ResBlock(width_in, width_out)
            for width_in, width_out in zip(widths, widths[1:])
        ]

        # Keep the layers in one flat Sequential so the parameter names
        # (features.0 ... features.12) match the saved checkpoint.
        self.features = nn.Sequential(
            *stem,
            *blocks,
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.classifier = nn.Sequential(nn.Linear(512, num_of_classes))

    def forward(self, x):
        """Return the classifier output for the input batch ``x``."""
        pooled = self.features(x)
        flat = pooled.flatten(1)
        return self.classifier(flat)
|
69 |
+
|
70 |
+
# Load model
device = "cpu"
model = Resnet18().to(device)
# map_location remaps the checkpoint's tensors onto `device`, so loading
# still works here if the file was saved with tensors on another device.
model.load_state_dict(torch.load("resnet_mnist_cpu.pth", map_location=device))
# The server only runs inference, so switch the model to evaluation mode.
model.eval()

# Define image preprocessing: grayscale, resize to 224x224, convert to tensor
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Initialize Flask app
app = Flask(__name__)
# Apply flask-cors to the app so pages on other origins can call the API
CORS(app)
|
82 |
+
|
83 |
+
@app.route("/")
def home():
    """Serve the drawing page rendered from the ``index.html`` template."""
    page = render_template("index.html")
    return page
|
86 |
+
|
87 |
+
# Route to handle image predictions
@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded image and return the predicted label as JSON.

    Expects a multipart form upload with the image in the ``image`` field.

    Returns:
        ``{"prediction": <label>}`` on success, or ``{"error": <message>}``
        with HTTP status 400 when the ``image`` field is missing or the
        upload cannot be decoded as an image.
    """
    upload = request.files.get("image")
    if upload is None:
        return jsonify({"error": "Missing 'image' file in request."}), 400

    try:
        image = Image.open(io.BytesIO(upload.read()))
        batch = transform(image).unsqueeze(0).to(device)
    except (OSError, ValueError):
        # PIL signals undecodable or truncated data with OSError (which
        # includes UnidentifiedImageError). That is a client error, so
        # answer with 400 and a JSON body instead of letting it become a 500.
        return jsonify({"error": "Uploaded file is not a readable image."}), 400

    with torch.no_grad():
        outputs = model(batch)
        _, predicted = torch.max(outputs, 1)

    class_labels = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]  # Modify based on your dataset
    prediction = class_labels[predicted.item()]

    return jsonify({"prediction": prediction})
|
102 |
+
|
103 |
+
if __name__ == "__main__":
    # Listen on all interfaces on port 7860, the port the Dockerfile exposes.
    app.run(host="0.0.0.0", port=7860)
|
static/script.js
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Shared with clearCanvas() and sendToServer(); assigned once the DOM is ready.
let canvas;
let ctx;

// Wire up mouse drawing on the canvas after the document has been parsed.
document.addEventListener("DOMContentLoaded", () => {
    canvas = document.getElementById("drawingCanvas");
    ctx = canvas.getContext("2d");
    let isDrawing = false;

    // Paint a 20x20 black square at the pointer while a stroke is active.
    function draw(event) {
        if (!isDrawing) return;
        ctx.fillStyle = "black";
        ctx.fillRect(event.offsetX, event.offsetY, 20, 20);
    }

    canvas.addEventListener("mousedown", (event) => {
        isDrawing = true;
        // Paint immediately so a click without movement still leaves a mark.
        draw(event);
    });
    canvas.addEventListener("mouseup", () => isDrawing = false);
    // The canvas never receives a mouseup that happens outside it. End the
    // stroke when the pointer leaves, otherwise drawing would resume on
    // re-entry with no button held down.
    canvas.addEventListener("mouseleave", () => isDrawing = false);
    canvas.addEventListener("mousemove", draw);
});
|
18 |
+
|
19 |
+
/** Erase the whole drawing surface and blank the prediction text. */
function clearCanvas() {
    const { width, height } = canvas;
    ctx.clearRect(0, 0, width, height);

    const resultElement = document.getElementById("result");
    resultElement.innerText = "";
}
|
23 |
+
|
24 |
+
// Inverting the images as the training images had a black background with white text,
// opposite of what we get from the canvas
/**
 * Return a new canvas holding a colour-inverted, fully opaque copy of
 * `inputCanvas`. The input canvas is not modified.
 *
 * @param {HTMLCanvasElement} inputCanvas - Canvas to copy and invert.
 * @returns {HTMLCanvasElement} Off-screen canvas of the same size.
 */
function invertCanvasColors(inputCanvas) {
    const tempCanvas = document.createElement("canvas");

    // Set the size of the temporary canvas to match the original canvas
    tempCanvas.width = inputCanvas.width;
    tempCanvas.height = inputCanvas.height;
    const tempCtx = tempCanvas.getContext("2d");

    // Untouched canvas pixels are transparent, not white. Flatten the drawing
    // onto an opaque white background first so that inversion yields a solid
    // black background, rather than relying on how the browser stores the
    // colour of fully transparent pixels.
    tempCtx.fillStyle = "white";
    tempCtx.fillRect(0, 0, tempCanvas.width, tempCanvas.height);

    // Draw the original canvas onto the temporary canvas
    tempCtx.drawImage(inputCanvas, 0, 0);

    // Get the pixel data of the image
    const imageData = tempCtx.getImageData(0, 0, tempCanvas.width, tempCanvas.height);
    const data = imageData.data;

    // Invert each pixel's RGB values; the alpha byte (i + 3) is left as is
    for (let i = 0; i < data.length; i += 4) {
        data[i] = 255 - data[i];         // Red channel
        data[i + 1] = 255 - data[i + 1]; // Green channel
        data[i + 2] = 255 - data[i + 2]; // Blue channel
    }

    tempCtx.putImageData(imageData, 0, 0);
    return tempCanvas;
}
|
50 |
+
|
51 |
+
/**
 * Upload the inverted drawing as a PNG to the prediction endpoint and show
 * the predicted digit, or a failure message, in the #result element.
 */
function sendToServer() {
    const resultElement = document.getElementById("result");

    const invertedCanvas = invertCanvasColors(canvas);
    const image = invertedCanvas.toDataURL("image/png");
    const blob = dataURItoBlob(image);
    const formData = new FormData();
    formData.append("image", blob, "drawing.png");

    fetch("https://ramachandrankulothungan-digit-doodle-recognition.hf.space/predict", {
        method: "POST",
        body: formData
    })
        .then(response => {
            // fetch() only rejects on network failure; an HTTP error status
            // still resolves, so it has to be checked explicitly.
            if (!response.ok) {
                throw new Error("Server responded with status " + response.status);
            }
            return response.json();
        })
        .then(data => {
            resultElement.innerText = "Prediction: " + data.prediction;
        })
        .catch(error => {
            console.error("Error:", error);
            // Tell the user as well, instead of failing silently.
            resultElement.innerText = "Prediction failed. Please try again.";
        });
}
|
66 |
+
|
67 |
+
/**
 * Decode the base64 payload of a data URI (the part after the first comma)
 * into a Blob typed as "image/png".
 */
function dataURItoBlob(dataURI) {
    const base64Payload = dataURI.split(",")[1];
    const binary = atob(base64Payload);

    // atob() returns one character per byte, so each char code is a byte value.
    const bytes = Uint8Array.from(binary, (ch) => ch.charCodeAt(0));

    return new Blob([bytes], { type: "image/png" });
}
|
templates/index.html
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Digit Doodle Recognition</title>
    <!-- Tailwind CSS from the CDN provides the utility classes used below -->
    <script src="https://cdn.tailwindcss.com"></script>
    <!-- Drawing and prediction logic; defines clearCanvas() and sendToServer() -->
    <script src="{{ url_for('static', filename='script.js') }}"></script>
</head>
<body class="bg-gray-900 text-gray-200 font-sans min-h-screen">
    <div class="container mx-auto p-6">
        <!-- Title -->
        <h1 class="text-4xl font-bold text-center mb-6">Digit Doodle Recognition</h1>

        <!-- Canvas Section -->
        <div class="flex flex-col items-center gap-4">
            <!-- Canvas: script.js looks this element up by its id -->
            <canvas id="drawingCanvas" width="500" height="500" class="bg-gray-800 border border-gray-700 shadow-md rounded-lg"></canvas>

            <!-- Buttons for Canvas -->
            <div class="flex gap-4">
                <button
                    onclick="clearCanvas()"
                    class="bg-blue-600 hover:bg-blue-700 text-white font-medium py-2 px-4 rounded shadow-md transition duration-300">
                    Clear
                </button>
                <button
                    onclick="sendToServer()"
                    class="bg-green-600 hover:bg-green-700 text-white font-medium py-2 px-4 rounded shadow-md transition duration-300">
                    Predict
                </button>
            </div>

            <!-- Prediction Result: script.js writes the prediction text here -->
            <p id="result" class="text-lg font-medium text-green-400 mt-2"></p>
        </div>

    </div>
</body>
</html>
|