Update app.py
Browse files
app.py
CHANGED
@@ -2,41 +2,96 @@ import gradio as gr
|
|
2 |
import torch
|
3 |
from torchvision import transforms
|
4 |
from PIL import Image
|
5 |
-
import
|
|
|
|
|
6 |
|
7 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
9 |
-
|
10 |
-
model
|
|
|
|
|
|
|
|
|
|
|
11 |
model.to(device)
|
12 |
model.eval()
|
13 |
|
14 |
-
# Define
|
15 |
# Preprocessing applied to every image before it reaches the model.
_preprocessing_steps = [
    # Bring every input to the fixed 224x224 resolution.
    transforms.Resize((224, 224)),
    # PIL image -> float tensor.
    transforms.ToTensor(),
    # Per-channel normalisation (the mean/std values torchvision documents
    # for its ImageNet-pretrained models).
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocessing_steps)
|
20 |
|
21 |
-
# Define
|
22 |
def predict(image):
    """Classify an uploaded image and return the predicted class name.

    Args:
        image: The image as a NumPy array (what ``gr.Image(type="numpy")``
            delivers), or ``None`` when the input component is empty.

    Returns:
        One of the ten class-name strings, or ``""`` when ``image`` is
        ``None``.
    """
    # A live interface also fires when the image is cleared, in which case
    # Gradio passes None; return an empty label instead of crashing inside
    # Image.fromarray.
    if image is None:
        return ""

    # NOTE(review): the order presumably mirrors the label indices used
    # during training -- confirm before editing this list.
    class_names = ('OUTSWING', 'STRAIGHT', 'BACK_OF_HAND', 'CARROM', 'CROSSSEAM',
                   'GOOGLY', 'INSWING', 'KNUCKLE', 'LEGSPIN', 'OFFSPIN')

    # Force three channels so grayscale/RGBA arrays still match the
    # 3-channel Normalize step in `transform`.
    pil_image = Image.fromarray(image).convert("RGB")
    batch = transform(pil_image).unsqueeze(0)  # add the batch dimension
    batch = batch.to(device)

    with torch.no_grad():  # inference only, no autograd bookkeeping
        outputs = model(batch)
        _, predicted = torch.max(outputs, 1)

    return class_names[predicted.item()]
|
37 |
|
38 |
# Create the Gradio Interface
|
39 |
-
# Build the two components first, then wire them to predict().
_image_input = gr.Image(type="numpy")  # hands predict() a NumPy array
_label_output = gr.Text()  # displays the returned class name
iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    live=True,  # re-run predict() whenever the input changes (no submit click)
)
|
|
|
2 |
import torch
|
3 |
from torchvision import transforms
|
4 |
from PIL import Image
|
5 |
+
import torch.nn as nn
|
6 |
+
import os
|
7 |
+
from torchvision import models
|
8 |
|
9 |
+
# Custom Residual Block
|
10 |
+
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages with a 1x1-conv shortcut.

    The shortcut output is added to the main branch before the final ReLU.

    Args:
        in_channels: Number of channels in the incoming feature map.
        out_channels: Number of channels produced by both branches.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Main branch. Attribute names are part of the checkpoint's
        # state_dict keys, so they must not be renamed.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut branch: 1x1 projection so channel counts line up.
        self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        # Registered but never applied in forward(); it still contributes
        # parameters and buffers to the state_dict.
        self.skip_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + skip(x))``."""
        shortcut = self.skip(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        out += shortcut  # merge the shortcut into the main branch
        return self.relu(out)
|
30 |
+
|
31 |
+
# EfficientNet Model with Novelty (Residual Block)
|
32 |
+
class EfficientNetWithNovelty(nn.Module):
    """EfficientNet-B0 with an extra residual block ahead of the classifier.

    Args:
        num_classes: Number of outputs of the replacement final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # NOTE(review): pretrained=True loads the pre-trained B0 weights;
        # newer torchvision releases deprecate this flag in favour of
        # `weights=` -- confirm against the installed version.
        backbone = models.efficientnet_b0(pretrained=True)

        # Replace the stock output layer with one sized for our classes.
        head_inputs = backbone.classifier[1].in_features
        backbone.classifier[1] = nn.Linear(head_inputs, num_classes)
        self.model = backbone

        # Extra block applied to the backbone's feature map
        # (1280 is the output channels from EfficientNet B0).
        self.residual_block = ResidualBlock(1280, 1280)

    def forward(self, x):
        """Run features -> residual block -> spatial mean -> classifier."""
        feature_map = self.model.features(x)
        feature_map = self.residual_block(feature_map)

        # Global average pooling: average over axes 2 and 3.
        pooled = feature_map.mean(dim=(2, 3))
        return self.model.classifier(pooled)
|
57 |
+
|
58 |
+
# Load the model and weights
|
59 |
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint location and class count; update both to match your model.
model_path = 'final_model.pth'
num_classes = 10

# Build the network, restore the saved weights, then switch to inference mode.
model = EfficientNetWithNovelty(num_classes)
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
_checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(_checkpoint)
model.to(device)
model.eval()
|
69 |
|
70 |
+
# Define image transformations (same as during training)
|
71 |
# Preprocessing applied to every image before it reaches the model.
_preprocessing_steps = [
    # Bring every input to the fixed 224x224 resolution.
    transforms.Resize((224, 224)),
    # PIL image -> float tensor.
    transforms.ToTensor(),
    # Per-channel normalisation (the mean/std values torchvision documents
    # for its ImageNet-pretrained models).
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocessing_steps)
|
76 |
|
77 |
+
# Define the prediction function for Gradio
|
78 |
def predict(image):
    """Classify an uploaded image and return the predicted class name.

    Args:
        image: The image as a NumPy array (what ``gr.Image(type="numpy")``
            delivers), or ``None`` when the input component is empty.

    Returns:
        One of the ten class-name strings, or ``""`` when ``image`` is
        ``None``.
    """
    # A live interface also fires when the image is cleared, in which case
    # Gradio passes None; return an empty label instead of crashing inside
    # Image.fromarray.
    if image is None:
        return ""

    # NOTE(review): the order presumably mirrors the label indices used
    # during training -- confirm before editing this list.
    class_names = ('OUTSWING', 'STRAIGHT', 'BACK_OF_HAND', 'CARROM', 'CROSSSEAM',
                   'GOOGLY', 'INSWING', 'KNUCKLE', 'LEGSPIN', 'OFFSPIN')

    # Force three channels so grayscale/RGBA arrays still match the
    # 3-channel Normalize step in `transform`.
    pil_image = Image.fromarray(image).convert("RGB")
    batch = transform(pil_image).unsqueeze(0)  # add the batch dimension
    batch = batch.to(device)

    with torch.no_grad():  # inference only, no autograd bookkeeping
        outputs = model(batch)
        _, predicted = torch.max(outputs, 1)

    return class_names[predicted.item()]
|
92 |
|
93 |
# Create the Gradio Interface
|
94 |
+
# Build the two components first, then wire them to predict().
_image_input = gr.Image(type="numpy")  # hands predict() a NumPy array
_label_output = gr.Text()  # displays the returned class name
iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    live=True,  # re-run predict() whenever the input changes (no submit click)
)
|