brightlembo commited on
Commit
f4b233e
·
verified ·
1 Parent(s): 22cdac0

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +51 -0
  2. class_names.json +5 -0
  3. efficientnet_b7_bestv1.pth +3 -0
  4. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ import gradio as gr
5
+ import json
6
+
7
+ # Charger les noms des classes
8
+ with open("class_names.json", "r") as f:
9
+ class_names = json.load(f)
10
+
11
+ # Charger le modèle
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ model = torch.load("efficientnet_b7_best.pth", map_location=device)
14
+ model.eval() # Mode évaluation
15
+
16
+ # Définir la taille de l'image
17
+ image_size = (224, 224)
18
+
19
+ # Transformation pour l'image
20
+ class GrayscaleToRGB:
21
+ def __call__(self, img):
22
+ return img.convert("RGB")
23
+
24
+ valid_test_transforms = transforms.Compose([
25
+ transforms.Grayscale(num_output_channels=1),
26
+ transforms.Resize(image_size),
27
+ GrayscaleToRGB(), # Conversion en RGB
28
+ transforms.ToTensor(),
29
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
30
+ ])
31
+
32
+ # Fonction de prédiction
33
+ def predict_image(image):
34
+ image_tensor = valid_test_transforms(image).unsqueeze(0).to(device)
35
+ with torch.no_grad():
36
+ outputs = model(image_tensor)
37
+ _, predicted_class = torch.max(outputs, 1)
38
+ predicted_label = class_names[predicted_class.item()]
39
+ return predicted_label
40
+
41
+ # Interface Gradio
42
+ interface = gr.Interface(
43
+ fn=predict_image,
44
+ inputs=gr.Image(type="pil"),
45
+ outputs="text",
46
+ title="Prédiction d'images avec PyTorch",
47
+ description="Chargez une image pour obtenir une prédiction de classe."
48
+ )
49
+
50
+ if __name__ == "__main__":
51
+ interface.launch()
class_names.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [
2
+ "disgust",
3
+ "happy",
4
+ "sad"
5
+ ]
efficientnet_b7_bestv1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c67597a1fddfac92cf0b7dc0c1b18bc6e18b56cc01733dd1c360a3b5355958e
3
+ size 262038038
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ Pillow