svjack commited on
Commit
aa24a87
·
verified ·
1 Parent(s): 8beb4c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +221 -0
app.py CHANGED
@@ -1,3 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import spaces
3
  import gradio as gr
 
1
+ '''
2
+ import os
3
+ from datasets import load_dataset
4
+ from PIL import Image
5
+
6
+ def save_images_locally(dataset_name="svjack/InfiniteYou_PosterCraft_Wang_Leehom_Poster_FP8_WAV_text_mask"):
7
+ # 加载数据集
8
+ dataset = load_dataset(dataset_name)["train"]
9
+
10
+ # 创建文件夹
11
+ image_dir = "original_images"
12
+ mask_dir = "mask_images"
13
+ os.makedirs(image_dir, exist_ok=True)
14
+ os.makedirs(mask_dir, exist_ok=True)
15
+
16
+ # 遍历数据集并保存图像
17
+ for idx, example in enumerate(dataset):
18
+ # 保存原图
19
+ image_path = os.path.join(image_dir, f"{idx:04d}.png")
20
+ example["Wang_Leehom_poster_image"].save(image_path)
21
+
22
+ # 保存 mask 图像
23
+ mask_path = os.path.join(mask_dir, f"{idx:04d}.png")
24
+ example["mask_image"].convert("RGB").save(mask_path)
25
+
26
+ print(f"✅ 原图已保存至 {image_dir}")
27
+ print(f"✅ Mask 图像已保存至 {mask_dir}")
28
+
29
+ if __name__ == "__main__":
30
+ save_images_locally()
31
+
32
+ import cv2
33
+ import numpy as np
34
+ from PIL import Image
35
+
36
+ def blur_mask_opencv(
37
+ mask: Image.Image,
38
+ blur_factor: int,
39
+ edge_expand: int = 0 # 新增参数:边缘扩展像素值
40
+ ) -> Image.Image:
41
+ """
42
+ 增强版遮罩处理:支持边缘扩展 + 高斯模糊羽化
43
+ 参数:
44
+ mask: PIL格式的遮罩图像(支持 RGBA/RGB/L 模式)
45
+ blur_factor: 模糊强度(奇数)
46
+ edge_expand: 边缘扩展像素值(非负整数)
47
+ """
48
+ # 转换为OpenCV格式(NumPy数组)
49
+ mask_np = np.array(mask)
50
+ original_mode = mask.mode # 保存原始色彩模式
51
+
52
+ # 通道分离逻辑
53
+ if original_mode == 'RGBA':
54
+ alpha = mask_np[:, :, 3]
55
+ processed = alpha
56
+ elif original_mode == 'RGB':
57
+ processed = cv2.cvtColor(mask_np, cv2.COLOR_RGB2GRAY)
58
+ else: # L模式(灰度)
59
+ processed = mask_np
60
+
61
+ # 边缘扩展(核心新增功能)
62
+ if edge_expand > 0:
63
+ kernel = np.ones((3, 3), np.uint8) # 3x3方形结构核
64
+ processed = cv2.dilate(
65
+ processed,
66
+ kernel,
67
+ iterations=edge_expand # 扩展次数 = 扩展像素值
68
+ )
69
+
70
+ # 高斯模糊羽化
71
+ kernel_size = blur_factor | 1 # 确保为奇数
72
+ blurred = cv2.GaussianBlur(processed, (kernel_size, kernel_size), 0)
73
+
74
+ # 重建为PIL图像
75
+ if original_mode == 'RGBA':
76
+ result = mask_np.copy()
77
+ result[:, :, 3] = blurred # 仅替换Alpha通道
78
+ return Image.fromarray(result)
79
+ else:
80
+ return Image.fromarray(blurred)
81
+
82
+
83
+ #from PIL import Image
84
+ #mask = Image.open("image (51).jpg")
85
+
86
+ # 使用示例
87
+ #blurred_mask = blur_mask_opencv(mask, 33, 20)
88
+ #blurred_mask.save("b_mask.png")
89
+ #blurred_mask
90
+
91
+
92
+ import os
93
+ import numpy as np
94
+ from PIL import Image
95
+ import torch
96
+ from diffusers import FluxFillPipeline
97
+
98
+ # 初始化模型
99
+ pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
100
+ pipe.enable_model_cpu_offload()
101
+
102
+ def calculate_optimal_dimensions(image):
103
+ original_width, original_height = image.size
104
+ MIN_ASPECT_RATIO = 9 / 16
105
+ MAX_ASPECT_RATIO = 16 / 9
106
+ FIXED_DIMENSION = 1024
107
+ original_aspect_ratio = original_width / original_height
108
+
109
+ if original_aspect_ratio > 1:
110
+ width = FIXED_DIMENSION
111
+ height = round(FIXED_DIMENSION / original_aspect_ratio)
112
+ else:
113
+ height = FIXED_DIMENSION
114
+ width = round(FIXED_DIMENSION * original_aspect_ratio)
115
+
116
+ width = (width // 8) * 8
117
+ height = (height // 8) * 8
118
+ calculated_aspect_ratio = width / height
119
+
120
+ if calculated_aspect_ratio > MAX_ASPECT_RATIO:
121
+ width = (height * MAX_ASPECT_RATIO // 8) * 8
122
+ elif calculated_aspect_ratio < MIN_ASPECT_RATIO:
123
+ height = (width / MIN_ASPECT_RATIO // 8) * 8
124
+
125
+ width = max(width, 576) if width == FIXED_DIMENSION else width
126
+ height = max(height, 576) if height == FIXED_DIMENSION else height
127
+ return width, height
128
+
129
+ def calculate_white_pixel_ratio(mask_image):
130
+ mask_array = np.array(mask_image.convert('L'))
131
+ white_pixels = np.sum(mask_array == 255)
132
+ total_pixels = mask_array.size
133
+ return white_pixels / total_pixels
134
+
135
+ def inpaint(image_path, mask_path, output_path):
136
+ image = Image.open(image_path).convert("RGB")
137
+ mask = Image.open(mask_path).convert("L")
138
+ #mask = blur_mask_opencv(mask, 33, 20)
139
+ mask = blur_mask_opencv(mask, 33, 60)
140
+
141
+ white_ratio = calculate_white_pixel_ratio(mask)
142
+ if white_ratio >= 1 / 3:
143
+ image.save(output_path)
144
+ return
145
+
146
+ width, height = calculate_optimal_dimensions(image)
147
+ result = pipe(
148
+ prompt="",
149
+ height=height,
150
+ width=width,
151
+ image=image,
152
+ mask_image=mask,
153
+ num_inference_steps=40,
154
+ guidance_scale=28,
155
+ ).images[0]
156
+ result.save(output_path)
157
+
158
+ def process_all_image_pairs():
159
+ image_dir = "original_images"
160
+ mask_dir = "mask_images"
161
+ output_dir = "inpaint_results"
162
+ os.makedirs(output_dir, exist_ok=True)
163
+
164
+ for filename in os.listdir(image_dir):
165
+ base_name = os.path.splitext(filename)[0]
166
+ image_path = os.path.join(image_dir, filename)
167
+ mask_path = os.path.join(mask_dir, filename)
168
+ output_path = os.path.join(output_dir, filename)
169
+
170
+ if not os.path.exists(mask_path):
171
+ continue
172
+
173
+ try:
174
+ inpaint(image_path, mask_path, output_path)
175
+ except Exception as e:
176
+ print(f"❌ 处理 {filename} 时出错: {e}")
177
+
178
+ print(f"✅ 修复结果已保存至 {output_dir}")
179
+
180
+ if __name__ == "__main__":
181
+ process_all_image_pairs()
182
+
183
+ import os
184
+ from datasets import load_dataset, DatasetDict, Image as HfImage
185
+ from PIL import Image
186
+
187
+ def update_dataset_with_inpaint_results():
188
+ dataset_name = "svjack/InfiniteYou_PosterCraft_Wang_Leehom_Poster_FP8_WAV_text_mask"
189
+ input_dataset = load_dataset(dataset_name)["train"]
190
+
191
+ output_dir = "inpaint_results"
192
+ output_dataset_path = "InfiniteYou_PosterCraft_Wang_Leehom_Poster_FP8_WAV_text_mask_inpaint"
193
+ os.makedirs(output_dataset_path, exist_ok=True)
194
+
195
+ def add_inpaint_image(example, idx):
196
+ output_path = os.path.join(output_dir, f"{idx:04d}.png")
197
+
198
+ if os.path.exists(output_path):
199
+ example["inpaint_image"] = output_path
200
+ else:
201
+ # 如果没有生成修复图像,则使用原始图像
202
+ example["inpaint_image"] = example["Wang_Leehom_poster_image"]
203
+
204
+ return example
205
+
206
+ updated_dataset = input_dataset.map(
207
+ lambda ex, idx: add_inpaint_image(ex, idx),
208
+ with_indices=True,
209
+ batched=False,
210
+ num_proc=1
211
+ )
212
+
213
+ updated_dataset = updated_dataset.cast_column("inpaint_image", HfImage())
214
+ updated_dataset.save_to_disk(output_dataset_path)
215
+
216
+ print(f"✅ 更新后的数据集已保存至 {output_dataset_path}")
217
+
218
+ if __name__ == "__main__":
219
+ update_dataset_with_inpaint_results()
220
+ '''
221
+
222
  import torch
223
  import spaces
224
  import gradio as gr