Update app.py
Browse files
app.py
CHANGED
@@ -72,33 +72,16 @@ def label_to_color_image(label):
|
|
72 |
raise ValueError("label value too large.")
|
73 |
return colormap[label]
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
input_img = np.array(input_img) # ์ด๋ฏธ์ง๋ฅผ ๋ํ์ด ๋ฐฐ์ด๋ก ๋ณํ
|
78 |
|
79 |
-
|
80 |
-
inputs = feature_extractor(images=input_img, return_tensors="pt")
|
81 |
-
outputs = model(**inputs)
|
82 |
-
pred_label = np.argmax(outputs.logits[0].numpy(), axis=0)
|
83 |
-
pred_img = label_to_color_image(pred_label)
|
84 |
|
85 |
-
|
|
|
|
|
86 |
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
if np.max(label) >= len(colormap):
|
92 |
-
raise ValueError("label value too large.")
|
93 |
-
return colormap[label]
|
94 |
-
|
95 |
-
# Gradio ์ธํฐํ์ด์ค ์์ฑ
|
96 |
-
iface = gr.Interface(
|
97 |
-
fn=image_segmentation,
|
98 |
-
inputs="image",
|
99 |
-
outputs="image"
|
100 |
-
)
|
101 |
-
|
102 |
-
# ์ธํฐํ์ด์ค ์คํ
|
103 |
-
iface.launch()
|
104 |
|
|
|
72 |
raise ValueError("label value too large.")
|
73 |
return colormap[label]
|
74 |
|
75 |
+
def draw_plot(pred_img, seg, show_seg=False):
|
76 |
+
fig = plt.figure(figsize=(20, 15))
|
|
|
77 |
|
78 |
+
grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
|
|
|
|
|
|
|
|
|
79 |
|
80 |
+
plt.subplot(grid_spec[0])
|
81 |
+
plt.imshow(pred_img)
|
82 |
+
plt.axis('off')
|
83 |
|
84 |
+
if show_seg:
|
85 |
+
unique_labels = np.unique(seg.numpy().astype("uint8"))
|
86 |
+
ax = plt.subplot(grid_spec[1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|