DimaML committed
Commit 343d11b · verified · 1 Parent(s): 1d045ca

Update app.py

Files changed (1)
  1. app.py +32 -14
app.py CHANGED
@@ -1,24 +1,42 @@
 import gradio as gr
 import numpy as np
+import torch
+import torch.nn as nn
+from torchvision import transforms, datasets, models

-def gray(image):
-    grayed_image = np.mean(image, 2)
-    image_max = np.max(grayed_image)
-
-    if image_max > 1:
-        grayed_image = grayed_image / image_max
-
-    return grayed_image
+# Preprocessing pipeline bundled with the pretrained ResNet-18 weights
+transformer = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
+
+device = torch.device("cpu")
+class_names = ['Anger', 'Disgust', 'Fear', 'Happy', 'Pain', 'Sad']
+classes_count = len(class_names)
+
+# ResNet-18 backbone with the classifier head resized to the six emotion classes
+model = models.resnet18(weights='DEFAULT').to(device)
+model.fc = nn.Sequential(
+    nn.Linear(512, classes_count)
+)
+model.load_state_dict(torch.load('./model_param.pt', map_location=device), strict=False)
+
+def predict(image):
+    image = transformer(image).unsqueeze(0).to(device)
+    model.eval()
+
+    with torch.inference_mode():
+        pred = torch.softmax(model(image), dim=1)
+
+    # Class-name -> probability mapping, as expected by the gr.Label output
+    preds_and_labels = {class_names[i]: pred[0][i].item() for i in range(len(pred[0]))}
+
+    return preds_and_labels

 app = gr.Interface(
-    gray,
-    'image',
-    gr.Image(format='png'),
-    examples=[
-        './example1.jpg',
-        './example2.jpg',
-        './example3.jpg',
-    ],
+    predict,
+    gr.Image(type='pil'),
+    gr.Label(label='Predictions', num_top_classes=classes_count),
+    #examples=[
+    #    './example1.jpg',
+    #    './example2.jpg',
+    #    './example3.jpg',
+    #],
     live=True
 )
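For quick verification, a minimal local smoke test of the new predict function is sketched below. It is not part of the commit: it assumes model_param.pt sits next to app.py and that one of the (currently commented-out) example images is available, and the launch call is only needed when running app.py directly rather than on Spaces.

# Hypothetical local check, appended to app.py for testing only
from PIL import Image

probs = predict(Image.open('./example1.jpg').convert('RGB'))
print(max(probs, key=probs.get))   # most probable of the six emotion labels

app.launch()   # serve the interface locally (default: http://127.0.0.1:7860)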