File size: 3,023 Bytes
64a8cc7 40eea6c 64a8cc7 3228ea4 64a8cc7 d570434 64a8cc7 617dc9e 64a8cc7 617dc9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import gradio as gr
import torch
from torchvision import transforms, models
from PIL import Image
import torch.nn as nn
class CustomEfficientNet(nn.Module):
    """EfficientNet-B0 backbone followed by a configurable MLP classification head.

    The backbone's stock classifier is replaced by ``nn.Identity`` so the backbone
    emits its feature vector; ``custom_classifier`` maps those features to
    ``num_classes`` logits through ``num_layers`` blocks of
    ``Linear -> ReLU -> Dropout`` followed by a final ``Linear``.

    Args:
        num_classes: Number of output logits.
        num_layers: Number of hidden ``Linear -> ReLU -> Dropout`` blocks (may be 0).
        neurons_per_layer: Width of every hidden layer.
        pretrained: If True (the default, matching the original behaviour),
            initialise the backbone with torchvision's ImageNet weights, which may
            trigger a download. Pass False when a full checkpoint is loaded anyway.
        dropout: Dropout probability after each hidden layer (default 0.5).
    """

    def __init__(
        self,
        num_classes: int,
        num_layers: int,
        neurons_per_layer: int,
        *,
        pretrained: bool = True,
        dropout: float = 0.5,
    ):
        super().__init__()
        if pretrained:
            # NOTE: ``pretrained=True`` is deprecated since torchvision 0.13 in favour
            # of ``weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1``; it is kept
            # so this keeps working on older torchvision releases.
            self.base_model = models.efficientnet_b0(pretrained=True)
        else:
            # No weight keyword: randomly initialised backbone, nothing downloaded.
            self.base_model = models.efficientnet_b0()
        in_features = self.base_model.classifier[1].in_features
        # Drop the stock classifier so the backbone returns its feature vector.
        self.base_model.classifier = nn.Identity()

        # Layer order (and therefore state-dict indices) matches the original
        # implementation, so existing checkpoints still load.
        layers = []
        for _ in range(num_layers):
            layers.extend(
                [nn.Linear(in_features, neurons_per_layer), nn.ReLU(), nn.Dropout(dropout)]
            )
            in_features = neurons_per_layer
        # Use the running ``in_features`` rather than ``neurons_per_layer`` so that
        # ``num_layers == 0`` connects the output layer straight to the backbone.
        layers.append(nn.Linear(in_features, num_classes))
        self.custom_classifier = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (batch, num_classes) for the image batch ``x``.

        ``x`` is presumably a normalised (batch, 3, H, W) float tensor — TODO confirm.
        """
        features = self.base_model(x)
        # Flatten everything except the batch dimension; ``reshape`` (unlike
        # ``view``) also copes with non-contiguous tensors.
        features = features.reshape(features.size(0), -1)
        return self.custom_classifier(features)


def create_model(
    num_classes: int,
    num_layers: int,
    neurons_per_layer: int,
    *,
    pretrained: bool = True,
) -> CustomEfficientNet:
    """Build a :class:`CustomEfficientNet` with the given head configuration.

    ``pretrained`` is forwarded to the model; it defaults to True as before.
    """
    return CustomEfficientNet(
        num_classes, num_layers, neurons_per_layer, pretrained=pretrained
    )


def load_model(
    path: str, num_classes: int, num_layers: int, neurons_per_layer: int
) -> CustomEfficientNet:
    """Rebuild the architecture, load a state dict from ``path`` onto the CPU, and return it in eval mode.

    The head configuration must match the one used when the checkpoint was saved,
    otherwise ``load_state_dict`` raises on missing/unexpected/mismatched keys.
    """
    # The strict load below overwrites every parameter and buffer, so
    # downloading the ImageNet backbone weights first would be wasted work.
    model = create_model(num_classes, num_layers, neurons_per_layer, pretrained=False)
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # Switch Dropout (and BatchNorm) layers to inference behaviour.
    model.eval()
    return model
# Class names, index-aligned with the model's output logits. They are French
# card names (Coeur/Pique/Trefle/carreau = hearts/spades/clubs/diamonds;
# Valet/Dame/Roi = jack/queen/king).
# NOTE(review): the order equals Python's sorted() over these strings (uppercase
# suits before lowercase 'carreau', '10' before '2'), presumably the order of the
# training dataset's class folders — confirm against the training script.
class_names = [
    'Coeur 1', 'Coeur 10', 'Coeur 2', 'Coeur 3', 'Coeur 4', 'Coeur 5', 'Coeur 6',
    'Coeur 7', 'Coeur 8', 'Coeur 9', 'Coeur Dame', 'Coeur Roi', 'Coeur Valet',
    'Pique 1', 'Pique 10', 'Pique 2', 'Pique 3', 'Pique 4', 'Pique 5', 'Pique 6',
    'Pique 7', 'Pique 8', 'Pique 9', 'Pique Dame', 'Pique Roi', 'Pique Valet',
    'Trefle 1', 'Trefle 10', 'Trefle 2', 'Trefle 3', 'Trefle 4', 'Trefle 5', 'Trefle 6',
    'Trefle 7', 'Trefle 8', 'Trefle 9', 'Trefle Dame', 'Trefle Roi', 'Trefle Valet',
    'carreau 1', 'carreau 10', 'carreau 2', 'carreau 3', 'carreau 4', 'carreau 5',
    'carreau 6', 'carreau 7', 'carreau 8', 'carreau 9', 'carreau Dame', 'carreau Roi',
    'carreau Valet',
]

# Classifier-head hyperparameters; they must match the saved checkpoint.
# num_classes is derived from class_names (52) so the two cannot drift apart.
num_classes = len(class_names)
num_layers = 3
neurons_per_layer = 1024

# Load the trained model once, at import time.
model = load_model('card_classification_model.pth', num_classes, num_layers, neurons_per_layer)

# Preprocessing: resize to 224x224, convert to a float tensor, then normalise
# with the commonly used ImageNet per-channel mean/std (3 channels).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def predict(image):
    """Classify a playing-card image and return its predicted class name.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input; may be ``None``
            when the user submits without uploading anything.

    Returns:
        The entry of ``class_names`` with the highest model logit.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of a TypeError from transform.
        raise gr.Error("Please upload an image to classify.")
    # Uploads can be RGBA, paletted or grayscale; the 3-channel Normalize in
    # ``transform`` only accepts RGB, so convert first.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    # argmax returns the first maximal index, same as torch.max(..., 1).
    predicted_index = logits.argmax(dim=1).item()
    return class_names[predicted_index]
# Gradio front end: one PIL-image input, with the predicted class name shown in
# a Label output. Interface's first three positional parameters are
# fn, inputs and outputs.
iface = gr.Interface(
    predict,
    gr.Image(type="pil"),
    "label",
    description="Upload an image to classify",
)
# Start the web server (runs as soon as this module executes).
iface.launch()
|