"""Gradio demo that classifies an uploaded image into one of 52 playing cards.

An EfficientNet-B0 backbone is topped with a configurable stack of
fully-connected layers. Weights are loaded from
``card_classification_model.pth`` and the model is served through a Gradio
interface.
"""

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms


class CustomEfficientNet(nn.Module):
    """EfficientNet-B0 feature extractor followed by a custom MLP head.

    Args:
        num_classes: Number of output logits.
        num_layers: Number of hidden ``Linear -> ReLU -> Dropout`` stages in
            the head. ``0`` yields a single linear layer on the backbone
            features.
        neurons_per_layer: Width of each hidden stage.
        pretrained: Initialise the backbone with torchvision's ImageNet
            weights. Not needed when a full checkpoint is loaded afterwards.
    """

    def __init__(self, num_classes, num_layers, neurons_per_layer, pretrained=True):
        super().__init__()
        self.base_model = models.efficientnet_b0(pretrained=pretrained)
        in_features = self.base_model.classifier[1].in_features
        # Remove torchvision's classifier so the backbone outputs raw features.
        self.base_model.classifier = nn.Identity()

        layers = []
        for _ in range(num_layers):
            layers.append(nn.Linear(in_features, neurons_per_layer))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.5))
            in_features = neurons_per_layer
        # Use the running width (not neurons_per_layer) so that num_layers == 0
        # connects the output layer directly to the backbone features.
        layers.append(nn.Linear(in_features, num_classes))
        self.custom_classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        Assumes ``x`` is a normalized image batch (B, 3, H, W).
        """
        x = self.base_model(x)
        x = torch.flatten(x, 1)  # (B, features); also safe if non-contiguous
        return self.custom_classifier(x)


def create_model(num_classes, num_layers, neurons_per_layer, pretrained=True):
    """Build a ``CustomEfficientNet`` with the given head configuration."""
    return CustomEfficientNet(
        num_classes, num_layers, neurons_per_layer, pretrained=pretrained
    )


def load_model(path, num_classes, num_layers, neurons_per_layer):
    """Load a state dict from *path* onto CPU and return the model in eval mode.

    The head configuration must match the one used to save the checkpoint,
    otherwise ``load_state_dict`` raises.
    """
    # Every parameter is overwritten by the checkpoint below, so skip the
    # ImageNet download.
    model = create_model(num_classes, num_layers, neurons_per_layer, pretrained=False)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (on torch >= 1.13, weights_only=True is safer for a plain state dict).
    state_dict = torch.load(path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Class names, indexed by model output. Presumably this matches the training
# dataset's class indices (it looks like sorted folder names), so do not
# reorder it.
_SUITS = ("Coeur", "Pique", "Trefle", "carreau")
_RANKS = ("1", "10", "2", "3", "4", "5", "6", "7", "8", "9", "Dame", "Roi", "Valet")
class_names = [f"{suit} {rank}" for suit in _SUITS for rank in _RANKS]

# Parameters
num_classes = len(class_names)  # 52; kept in sync with the label list
num_layers = 3
neurons_per_layer = 1024

# Load the model
model = load_model(
    "card_classification_model.pth", num_classes, num_layers, neurons_per_layer
)

# Resize to the model's input size and normalize with ImageNet statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict(image):
    """Classify a PIL image and return the predicted card name.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image to classify.")
    # Uploads may be RGBA, palette or grayscale; the 3-channel Normalize and
    # the backbone both require RGB.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    predicted = logits.argmax(dim=1).item()
    return class_names[predicted]


# Create the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="label",
    description="Upload an image to classify",
)

if __name__ == "__main__":
    iface.launch()