EUNSEO56 commited on
Commit
a115d31
ยท
1 Parent(s): 310cd05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -9
app.py CHANGED
@@ -72,27 +72,33 @@ def label_to_color_image(label):
72
  raise ValueError("label value too large.")
73
  return colormap[label]
74
 
75
- # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ƒ์„ฑ
76
  def image_segmentation(input_img):
77
- # ์ด๋ฏธ์ง€๋ฅผ ๋„˜ํŒŒ์ด ๋ฐฐ์—ด๋กœ ๋ณ€ํ™˜
78
- input_img = np.array(input_img)
79
 
80
  # ์ด๋ฏธ์ง€ ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜ ์ˆ˜ํ–‰
81
  inputs = feature_extractor(images=input_img, return_tensors="pt")
82
  outputs = model(**inputs)
83
  pred_label = np.argmax(outputs.logits[0].numpy(), axis=0)
84
-
85
- # ์˜ˆ์ธก๋œ ๋ ˆ์ด๋ธ”์„ ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜
86
- pred_img = Image.fromarray(pred_label.astype(np.uint8), 'P')
87
- pred_img.putpalette(ade_palette()) # ํŒ”๋ ˆํŠธ ์„ค์ •
88
 
89
  return pred_img
90
 
 
 
 
 
 
 
 
 
91
  # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ƒ์„ฑ
92
- gr.Interface(
93
  fn=image_segmentation,
94
  inputs="image",
95
  outputs="image"
96
- ).launch()
97
 
 
 
98
 
 
72
  raise ValueError("label value too large.")
73
  return colormap[label]
74
 
75
+ # Gradio ์ธํ„ฐํŽ˜์ด์Šค์— ์‚ฌ์šฉ๋  ์ด๋ฏธ์ง€ ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜ ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค.
76
  def image_segmentation(input_img):
77
+ input_img = np.array(input_img) # ์ด๋ฏธ์ง€๋ฅผ ๋„˜ํŒŒ์ด ๋ฐฐ์—ด๋กœ ๋ณ€ํ™˜
 
78
 
79
  # ์ด๋ฏธ์ง€ ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜ ์ˆ˜ํ–‰
80
  inputs = feature_extractor(images=input_img, return_tensors="pt")
81
  outputs = model(**inputs)
82
  pred_label = np.argmax(outputs.logits[0].numpy(), axis=0)
83
+ pred_img = label_to_color_image(pred_label)
 
 
 
84
 
85
  return pred_img
86
 
87
+ def label_to_color_image(label):
88
+ if label.ndim != 2:
89
+ raise ValueError("Expect 2-D input label")
90
+
91
+ if np.max(label) >= len(colormap):
92
+ raise ValueError("label value too large.")
93
+ return colormap[label]
94
+
95
  # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ƒ์„ฑ
96
+ iface = gr.Interface(
97
  fn=image_segmentation,
98
  inputs="image",
99
  outputs="image"
100
+ )
101
 
102
+ # ์ธํ„ฐํŽ˜์ด์Šค ์‹คํ–‰
103
+ iface.launch()
104