sergiopaniego HF Staff commited on
Commit
0b822c2
·
1 Parent(s): f0585ee

Updated app

Browse files
Files changed (1) hide show
  1. app.py +41 -4
app.py CHANGED
@@ -5,18 +5,55 @@ import torch
5
  from PIL import Image
6
  import requests
7
  from transformers import DetrImageProcessor
 
 
 
 
 
8
 
9
  processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- zero = torch.Tensor([0]).cuda()
12
- print(zero.device) # <-- 'cpu' 🤔
13
 
14
  @spaces.GPU
15
  def detect(image):
16
  encoding = processor(image, return_tensors='pt')
17
  print(encoding.keys())
18
- print(zero.device) # <-- 'cuda:0' 🤗
19
- return f"Hello {encoding.keys()} Tensor"
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  demo = gr.Interface(fn=detect, inputs=gr.Image(label="Input image", type="pil"), outputs=gr.Image(label="Output image", type="pil"))
22
  demo.launch()
 
5
  from PIL import Image
6
  import requests
7
  from transformers import DetrImageProcessor
8
+ from transformers import DetrForObjectDetection
9
+ from random import choice
10
+ import matplotlib.pyplot as plt
11
+ import io
12
+
13
 
14
  processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
15
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
16
+
17
+
18
+ COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
19
+ [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
20
+
21
+ def generate_output_figure(in_pil_img, in_results):
22
+ plt.figure(figsize=(16, 10))
23
+ ax = plt.gca()
24
+
25
+ for prediction in in_results:
26
+ selected_color = choice(COLORS)
27
+
28
+ x, y = prediction['box']['xmin'], prediction['box']['ymin'],
29
+ w, h = prediction['box']['xmax'] - prediction['box']['xmin'], prediction['box']['ymax'] - prediction['box']['ymin']
30
+
31
+ ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
32
+ ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
33
+
34
+ plt.axis("off")
35
+
36
+ return plt.gcf()
37
 
 
 
38
 
39
  @spaces.GPU
40
  def detect(image):
41
  encoding = processor(image, return_tensors='pt')
42
  print(encoding.keys())
43
+
44
+ with torch.no_grad():
45
+ outputs = model(**encoding)
46
+
47
+ output_figure = generate_output_figure(image, outputs)
48
+
49
+ buf = io.BytesIO()
50
+ output_figure.savefig(buf, bbox_inches='tight')
51
+ buf.seek(0)
52
+ output_pil_img = Image.open(buf)
53
+
54
+ print(output_pil_img)
55
+
56
+ return output_pil_img
57
 
58
  demo = gr.Interface(fn=detect, inputs=gr.Image(label="Input image", type="pil"), outputs=gr.Image(label="Output image", type="pil"))
59
  demo.launch()