zhiweili committed on
Commit
6bdded7
·
1 Parent(s): 0ec7070

add croper

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. app.py +4 -49
  3. croper.py +71 -0
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  .vscode
2
- .DS_Store
 
 
1
  .vscode
2
+ .DS_Store
3
+ __pycache__
app.py CHANGED
@@ -1,10 +1,12 @@
1
  import gradio as gr
2
  import mediapipe as mp
3
  import numpy as np
 
4
  from PIL import Image
5
  from mediapipe.tasks import python
6
  from mediapipe.tasks.python import vision
7
  from scipy.ndimage import binary_dilation, label
 
8
 
9
  BG_COLOR = np.array([0, 0, 0], dtype=np.uint8) # black
10
  MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8) # white
@@ -15,10 +17,8 @@ base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
15
  options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
16
  segmenter = vision.ImageSegmenter.create_from_options(options)
17
  labels = segmenter.labels
18
- expand_size = 40
19
 
20
  def segment(input_image, category):
21
- original_height, original_width = input_image.size
22
  image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
23
  segmentation_result = segmenter.segment(image)
24
  category_mask = segmentation_result.category_mask
@@ -30,54 +30,9 @@ def segment(input_image, category):
30
  target_mask = get_clothes_mask(category_mask_np)
31
  else:
32
  target_mask = category_mask_np == 0
 
33
 
34
- target_indices = np.where(target_mask)
35
- start_y = np.min(target_indices[0]) - expand_size
36
- if start_y < 0:
37
- start_y = 0
38
- end_y = np.max(target_indices[0]) + expand_size
39
- if end_y > original_height:
40
- end_y = original_height
41
- start_x = np.min(target_indices[1]) - expand_size
42
- if start_x < 0:
43
- start_x = 0
44
- end_x = np.max(target_indices[1]) + expand_size
45
- if end_x > original_width:
46
- end_x = original_width
47
- target_height = end_y - start_y
48
- target_width = end_x - start_x
49
-
50
- # choose the max side length
51
- max_side_length = max(target_height, target_width)
52
- # calculate the crop area
53
- crop_mask = target_mask[start_y:end_y, start_x:end_x]
54
- crop_mask_height, crop_mask_width = crop_mask.shape
55
- crop_mask_start_y = (max_side_length - crop_mask_height) // 2
56
- crop_mask_end_y = crop_mask_start_y + crop_mask_height
57
- crop_mask_start_x = (max_side_length - crop_mask_width) // 2
58
- crop_mask_end_x = crop_mask_start_x + crop_mask_width
59
- # create a square mask
60
- crop_mask_square = np.zeros((max_side_length, max_side_length), dtype=target_mask.dtype)
61
- crop_mask_square[crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x] = crop_mask
62
- # create a square image
63
- crop_mask_image = Image.fromarray((crop_mask_square * 255).astype(np.uint8))
64
-
65
- crop_image = input_image.crop((start_x, start_y, end_x, end_y))
66
- crop_image_square = Image.new("RGB", (max_side_length, max_side_length))
67
- crop_image_square.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))
68
-
69
- # Generate solid color images for showing the output segmentation mask.
70
- image_data = image.numpy_view()
71
- fg_image = np.zeros(image_data.shape, dtype=np.uint8)
72
- fg_image[:] = MASK_COLOR
73
- bg_image = np.zeros(image_data.shape, dtype=np.uint8)
74
- bg_image[:] = BG_COLOR
75
-
76
- condition = np.stack((target_mask,) * 3, axis=-1) > 0.2
77
-
78
- output_image = np.where(condition, fg_image, bg_image)
79
- output_image = Image.fromarray(output_image)
80
- return crop_mask_image, crop_image_square
81
 
82
  def get_clothes_mask(category_mask_np):
83
  body_skin_mask = category_mask_np == 2
 
1
  import gradio as gr
2
  import mediapipe as mp
3
  import numpy as np
4
+
5
  from PIL import Image
6
  from mediapipe.tasks import python
7
  from mediapipe.tasks.python import vision
8
  from scipy.ndimage import binary_dilation, label
9
+ from croper import Croper
10
 
11
  BG_COLOR = np.array([0, 0, 0], dtype=np.uint8) # black
12
  MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8) # white
 
17
  options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
18
  segmenter = vision.ImageSegmenter.create_from_options(options)
19
  labels = segmenter.labels
 
20
 
21
  def segment(input_image, category):
 
22
  image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
23
  segmentation_result = segmenter.segment(image)
24
  category_mask = segmentation_result.category_mask
 
30
  target_mask = get_clothes_mask(category_mask_np)
31
  else:
32
  target_mask = category_mask_np == 0
33
+ croper = Croper(input_image, target_mask)
34
 
35
+ return croper.corp_mask_image()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  def get_clothes_mask(category_mask_np):
38
  body_skin_mask = category_mask_np == 2
croper.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ import numpy as np
3
+
4
+ from PIL import Image
5
+
6
class Croper:
    """Cut a padded, square region around a mask out of an image.

    The instance keeps the source image and the mask; calling
    ``corp_mask_image`` computes the crop and stores the resulting
    coordinates and images as attributes on the instance.
    """

    def __init__(
        self,
        input_image: Image.Image,
        target_mask: np.ndarray,
    ):
        """Store the image and the mask that selects the region of interest.

        Args:
            input_image: Source PIL image.
            target_mask: 2-D array whose non-zero (truthy) entries mark the
                region to crop around. Assumed to have the same height and
                width as ``input_image`` -- TODO confirm against callers.
        """
        self.input_image = input_image
        self.target_mask = target_mask

    def corp_mask_image(self, crop_length: int = 512, expand_size: int = 40):
        """Crop image and mask to a centred square around the mask.

        The bounding box of the non-zero mask pixels is grown by
        ``expand_size`` pixels on every side (clamped to the image bounds),
        then centred on a square canvas whose side is the longer box edge.
        The unused canvas area stays zero-filled (mask) / default-filled
        (image).

        Args:
            crop_length: Side length, in pixels, of the resized outputs.
            expand_size: Margin, in pixels, added around the mask bounding box.

        Returns:
            A ``(square_image, resized_square_image)`` tuple. The matching
            mask images are not returned; they are stored on
            ``self.square_mask_image`` and ``self.resized_square_mask_image``.

        Raises:
            ValueError: If ``target_mask`` contains no non-zero pixel.
        """
        target_mask = self.target_mask
        input_image = self.input_image
        # PIL reports size as (width, height).
        original_width, original_height = input_image.size

        rows, cols = np.nonzero(target_mask)
        if rows.size == 0:
            # Without this guard the min/max reductions below fail with an
            # opaque "zero-size array" error.
            raise ValueError("target_mask has no non-zero pixels to crop around")

        # Bounds are half-open [start, end). The +1 turns the inclusive max
        # index into an exclusive end, so the margin is the same on both
        # sides and the last mask row/column is kept even when the margin
        # is 0. Plain ints are used so PIL never receives NumPy scalars.
        start_y = max(int(rows.min()) - expand_size, 0)
        end_y = min(int(rows.max()) + 1 + expand_size, original_height)
        start_x = max(int(cols.min()) - expand_size, 0)
        end_x = min(int(cols.max()) + 1 + expand_size, original_width)
        mask_height = end_y - start_y
        mask_width = end_x - start_x

        # The square side is the longer edge of the padded bounding box.
        max_side_length = max(mask_height, mask_width)

        # Offsets that centre the cropped region inside the square canvas.
        crop_mask = target_mask[start_y:end_y, start_x:end_x]
        crop_mask_start_y = (max_side_length - mask_height) // 2
        crop_mask_end_y = crop_mask_start_y + mask_height
        crop_mask_start_x = (max_side_length - mask_width) // 2
        crop_mask_end_x = crop_mask_start_x + mask_width

        # Square mask: zeros everywhere except the centred crop.
        square_mask = np.zeros((max_side_length, max_side_length), dtype=target_mask.dtype)
        square_mask[crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x] = crop_mask
        square_mask_image = Image.fromarray((square_mask * 255).astype(np.uint8))

        # Square image: the cropped pixels pasted at the same offsets.
        crop_image = input_image.crop((start_x, start_y, end_x, end_y))
        square_image = Image.new("RGB", (max_side_length, max_side_length))
        square_image.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))

        # Crop box in the coordinates of the original image.
        self.origin_start_x = start_x
        self.origin_start_y = start_y
        self.origin_end_x = end_x
        self.origin_end_y = end_y

        # Position of the cropped content inside the square canvas.
        self.square_start_x = crop_mask_start_x
        self.square_start_y = crop_mask_start_y
        self.square_end_x = crop_mask_end_x
        self.square_end_y = crop_mask_end_y

        self.square_length = max_side_length
        self.square_mask_image = square_mask_image
        self.square_image = square_image

        self.resized_square_mask_image = square_mask_image.resize((crop_length, crop_length))
        self.resized_square_image = square_image.resize((crop_length, crop_length))

        return self.square_image, self.resized_square_image