Update app.py
Browse files
app.py
CHANGED
@@ -71,17 +71,24 @@ def label_to_color_image(label):
|
|
71 |
if np.max(label) >= len(colormap):
|
72 |
raise ValueError("label value too large.")
|
73 |
return colormap[label]
|
74 |
-
|
75 |
-
#
|
76 |
def image_segmentation(input_img):
|
77 |
-
#
|
|
|
|
|
|
|
78 |
inputs = feature_extractor(images=input_img, return_tensors="pt")
|
79 |
outputs = model(**inputs)
|
80 |
pred_label = np.argmax(outputs.logits[0].numpy(), axis=0)
|
81 |
-
|
|
|
|
|
|
|
82 |
|
83 |
return pred_img
|
84 |
|
|
|
85 |
gr.Interface(
|
86 |
fn=image_segmentation,
|
87 |
inputs="image",
|
|
|
71 |
if np.max(label) >= len(colormap):
|
72 |
raise ValueError("label value too large.")
|
73 |
return colormap[label]
|
74 |
+
|
75 |
+
# Gradio ์ธํฐํ์ด์ค ์์ฑ
|
76 |
def image_segmentation(input_img):
|
77 |
+
# ์ด๋ฏธ์ง๋ฅผ ๋ํ์ด ๋ฐฐ์ด๋ก ๋ณํ
|
78 |
+
input_img = np.array(input_img)
|
79 |
+
|
80 |
+
# ์ด๋ฏธ์ง ์ธ๊ทธ๋ฉํ
์ด์
์ํ
|
81 |
inputs = feature_extractor(images=input_img, return_tensors="pt")
|
82 |
outputs = model(**inputs)
|
83 |
pred_label = np.argmax(outputs.logits[0].numpy(), axis=0)
|
84 |
+
|
85 |
+
# ์์ธก๋ ๋ ์ด๋ธ์ ์ด๋ฏธ์ง๋ก ๋ณํ
|
86 |
+
pred_img = Image.fromarray(pred_label.astype(np.uint8), 'P')
|
87 |
+
pred_img.putpalette(ade_palette()) # ํ๋ ํธ ์ค์
|
88 |
|
89 |
return pred_img
|
90 |
|
91 |
+
# Gradio ์ธํฐํ์ด์ค ์์ฑ
|
92 |
gr.Interface(
|
93 |
fn=image_segmentation,
|
94 |
inputs="image",
|