dakkoong commited on
Commit
7c9490c
ยท
1 Parent(s): 83b7fa3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -7
app.py CHANGED
@@ -83,18 +83,26 @@ def sepia(input_img):
83
  logits = tf.transpose(logits, [0, 2, 3, 1])
84
  logits = tf.image.resize(
85
  logits, input_img.size[::-1]
86
- )
87
  seg = tf.math.argmax(logits, axis=-1)[0]
88
 
89
- # Return segmentation label image instead of Matplotlib Figure
90
- return seg.numpy()
 
 
 
 
 
 
 
 
 
91
 
92
- # Gradio Interface ์„ค์ •
93
  demo = gr.Interface(fn=sepia,
94
  inputs=gr.Image(shape=(800, 600)),
95
- outputs=['label'], # 'plot'์—์„œ 'label'๋กœ ๋ณ€๊ฒฝ
96
  examples=["cityoutdoor-1.jpg", "cityoutdoor-2.jpg", "cityoutdoor-3.jpg"],
97
  allow_flagging='never')
98
 
99
- # Gradio ์‹คํ–‰
100
- demo.launch()
 
83
  logits = tf.transpose(logits, [0, 2, 3, 1])
84
  logits = tf.image.resize(
85
  logits, input_img.size[::-1]
86
+ )
87
  seg = tf.math.argmax(logits, axis=-1)[0]
88
 
89
+ color_seg = np.zeros(
90
+ (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
91
+ )
92
+ for label, color in enumerate(colormap):
93
+ color_seg[seg.numpy() == label, :] = color
94
+
95
+ pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
96
+ pred_img = pred_img.astype(np.uint8)
97
+
98
+ fig = draw_plot(pred_img, seg)
99
+ return fig
100
 
 
101
  demo = gr.Interface(fn=sepia,
102
  inputs=gr.Image(shape=(800, 600)),
103
+ outputs=['plot'],
104
  examples=["cityoutdoor-1.jpg", "cityoutdoor-2.jpg", "cityoutdoor-3.jpg"],
105
  allow_flagging='never')
106
 
107
+
108
+ demo.launch()