sergiopaniego HF Staff commited on
Commit
19a9827
·
1 Parent(s): 0b822c2

Updated app

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -18,8 +18,10 @@ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
18
  COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
19
  [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
20
 
21
- def generate_output_figure(in_pil_img, in_results):
22
  plt.figure(figsize=(16, 10))
 
 
23
  ax = plt.gca()
24
 
25
  for prediction in in_results:
@@ -44,7 +46,9 @@ def detect(image):
44
  with torch.no_grad():
45
  outputs = model(**encoding)
46
 
47
- output_figure = generate_output_figure(image, outputs)
 
 
48
 
49
  buf = io.BytesIO()
50
  output_figure.savefig(buf, bbox_inches='tight')
 
18
  COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
19
  [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
20
 
21
+ def get_figure(in_pil_img, in_results):
22
  plt.figure(figsize=(16, 10))
23
+ plt.imshow(in_pil_img)
24
+ #pyplot.gcf()
25
  ax = plt.gca()
26
 
27
  for prediction in in_results:
 
46
  with torch.no_grad():
47
  outputs = model(**encoding)
48
 
49
+ print(outputs)
50
+
51
+ output_figure = get_figure(image, outputs)
52
 
53
  buf = io.BytesIO()
54
  output_figure.savefig(buf, bbox_inches='tight')