EUNSEO56 commited on
Commit
310cd05
ยท
1 Parent(s): 8004aa4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -71,17 +71,24 @@ def label_to_color_image(label):
71
  if np.max(label) >= len(colormap):
72
  raise ValueError("label value too large.")
73
  return colormap[label]
74
-
75
- # Create a Gradio interface
76
  def image_segmentation(input_img):
77
- # Perform image segmentation
 
 
 
78
  inputs = feature_extractor(images=input_img, return_tensors="pt")
79
  outputs = model(**inputs)
80
  pred_label = np.argmax(outputs.logits[0].numpy(), axis=0)
81
- pred_img = label_to_color_image(pred_label)
 
 
 
82
 
83
  return pred_img
84
 
 
85
  gr.Interface(
86
  fn=image_segmentation,
87
  inputs="image",
 
71
  if np.max(label) >= len(colormap):
72
  raise ValueError("label value too large.")
73
  return colormap[label]
74
+
75
+ # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ƒ์„ฑ
76
  def image_segmentation(input_img):
77
+ # ์ด๋ฏธ์ง€๋ฅผ ๋„˜ํŒŒ์ด ๋ฐฐ์—ด๋กœ ๋ณ€ํ™˜
78
+ input_img = np.array(input_img)
79
+
80
+ # ์ด๋ฏธ์ง€ ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜ ์ˆ˜ํ–‰
81
  inputs = feature_extractor(images=input_img, return_tensors="pt")
82
  outputs = model(**inputs)
83
  pred_label = np.argmax(outputs.logits[0].numpy(), axis=0)
84
+
85
+ # ์˜ˆ์ธก๋œ ๋ ˆ์ด๋ธ”์„ ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜
86
+ pred_img = Image.fromarray(pred_label.astype(np.uint8), 'P')
87
+ pred_img.putpalette(ade_palette()) # ํŒ”๋ ˆํŠธ ์„ค์ •
88
 
89
  return pred_img
90
 
91
+ # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ƒ์„ฑ
92
  gr.Interface(
93
  fn=image_segmentation,
94
  inputs="image",