drFarid commited on
Commit
64a8cc7
·
verified ·
1 Parent(s): 4ad0a70

first commit - create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+
6
+ class CustomEfficientNet(nn.Module):
7
+ def __init__(self, num_classes, num_layers, neurons_per_layer):
8
+ super(CustomEfficientNet, self).__init__()
9
+ self.base_model = models.efficientnet_b0(pretrained=True)
10
+ in_features = self.base_model.classifier[1].in_features
11
+ self.base_model.classifier = nn.Identity() # Remove the existing classifier
12
+
13
+ # Define custom layers
14
+ layers = []
15
+ for _ in range(num_layers):
16
+ layers.append(nn.Linear(in_features, neurons_per_layer))
17
+ layers.append(nn.ReLU())
18
+ layers.append(nn.Dropout(0.5))
19
+ in_features = neurons_per_layer
20
+
21
+ layers.append(nn.Linear(neurons_per_layer, num_classes))
22
+ self.custom_classifier = nn.Sequential(*layers)
23
+
24
+ def forward(self, x):
25
+ x = self.base_model(x)
26
+ x = x.view(x.size(0), -1) # Flatten the tensor
27
+ x = self.custom_classifier(x)
28
+ return x
29
+
30
+ def create_model(num_classes, num_layers, neurons_per_layer):
31
+ model = CustomEfficientNet(num_classes, num_layers, neurons_per_layer)
32
+ return model
33
+
34
+ def load_model(path, num_classes, num_layers, neurons_per_layer):
35
+ model = create_model(num_classes, num_layers, neurons_per_layer)
36
+ model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
37
+ model.eval()
38
+ return model
39
+
40
+ # Parameters
41
+ num_classes = 52
42
+ num_layers = 3
43
+ neurons_per_layer = 1024
44
+
45
+ # Load the model
46
+ model = load_model('card_classification_model.pth', num_classes, num_layers, neurons_per_layer)
47
+
48
+ # Define the transformation
49
+ transform = transforms.Compose([
50
+ transforms.Resize((224, 224)),
51
+ transforms.ToTensor(),
52
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
53
+ ])
54
+
55
+ # Class names
56
+ class_names = ['class1', 'class2', ..., 'class52'] # Replace with actual class names
57
+
58
+ def predict(image):
59
+ image = transform(image).unsqueeze(0)
60
+ with torch.no_grad():
61
+ outputs = model(image)
62
+ _, predicted = torch.max(outputs, 1)
63
+ return class_names[predicted[0]]
64
+
65
+ # Create the Gradio interface
66
+ iface = gr.Interface(
67
+ fn=predict,
68
+ inputs=gr.Image(type="pil"),
69
+ outputs="label",
70
+ description="Upload an image to classify"
71
+ )
72
+
73
+ iface.launch()