Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
import spaces
|
3 |
import gradio as gr
|
|
|
1 |
+
'''
|
2 |
+
import os
|
3 |
+
from datasets import load_dataset
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
def save_images_locally(dataset_name="svjack/InfiniteYou_PosterCraft_Wang_Leehom_Poster_FP8_WAV_text_mask"):
|
7 |
+
# 加载数据集
|
8 |
+
dataset = load_dataset(dataset_name)["train"]
|
9 |
+
|
10 |
+
# 创建文件夹
|
11 |
+
image_dir = "original_images"
|
12 |
+
mask_dir = "mask_images"
|
13 |
+
os.makedirs(image_dir, exist_ok=True)
|
14 |
+
os.makedirs(mask_dir, exist_ok=True)
|
15 |
+
|
16 |
+
# 遍历数据集并保存图像
|
17 |
+
for idx, example in enumerate(dataset):
|
18 |
+
# 保存原图
|
19 |
+
image_path = os.path.join(image_dir, f"{idx:04d}.png")
|
20 |
+
example["Wang_Leehom_poster_image"].save(image_path)
|
21 |
+
|
22 |
+
# 保存 mask 图像
|
23 |
+
mask_path = os.path.join(mask_dir, f"{idx:04d}.png")
|
24 |
+
example["mask_image"].convert("RGB").save(mask_path)
|
25 |
+
|
26 |
+
print(f"✅ 原图已保存至 {image_dir}")
|
27 |
+
print(f"✅ Mask 图像已保存至 {mask_dir}")
|
28 |
+
|
29 |
+
if __name__ == "__main__":
|
30 |
+
save_images_locally()
|
31 |
+
|
32 |
+
import cv2
|
33 |
+
import numpy as np
|
34 |
+
from PIL import Image
|
35 |
+
|
36 |
+
def blur_mask_opencv(
|
37 |
+
mask: Image.Image,
|
38 |
+
blur_factor: int,
|
39 |
+
edge_expand: int = 0 # 新增参数:边缘扩展像素值
|
40 |
+
) -> Image.Image:
|
41 |
+
"""
|
42 |
+
增强版遮罩处理:支持边缘扩展 + 高斯模糊羽化
|
43 |
+
参数:
|
44 |
+
mask: PIL格式的遮罩图像(支持 RGBA/RGB/L 模式)
|
45 |
+
blur_factor: 模糊强度(奇数)
|
46 |
+
edge_expand: 边缘扩展像素值(非负整数)
|
47 |
+
"""
|
48 |
+
# 转换为OpenCV格式(NumPy数组)
|
49 |
+
mask_np = np.array(mask)
|
50 |
+
original_mode = mask.mode # 保存原始色彩模式
|
51 |
+
|
52 |
+
# 通道分离逻辑
|
53 |
+
if original_mode == 'RGBA':
|
54 |
+
alpha = mask_np[:, :, 3]
|
55 |
+
processed = alpha
|
56 |
+
elif original_mode == 'RGB':
|
57 |
+
processed = cv2.cvtColor(mask_np, cv2.COLOR_RGB2GRAY)
|
58 |
+
else: # L模式(灰度)
|
59 |
+
processed = mask_np
|
60 |
+
|
61 |
+
# 边缘扩展(核心新增功能)
|
62 |
+
if edge_expand > 0:
|
63 |
+
kernel = np.ones((3, 3), np.uint8) # 3x3方形结构核
|
64 |
+
processed = cv2.dilate(
|
65 |
+
processed,
|
66 |
+
kernel,
|
67 |
+
iterations=edge_expand # 扩展次数 = 扩展像素值
|
68 |
+
)
|
69 |
+
|
70 |
+
# 高斯模糊羽化
|
71 |
+
kernel_size = blur_factor | 1 # 确保为奇数
|
72 |
+
blurred = cv2.GaussianBlur(processed, (kernel_size, kernel_size), 0)
|
73 |
+
|
74 |
+
# 重建为PIL图像
|
75 |
+
if original_mode == 'RGBA':
|
76 |
+
result = mask_np.copy()
|
77 |
+
result[:, :, 3] = blurred # 仅替换Alpha通道
|
78 |
+
return Image.fromarray(result)
|
79 |
+
else:
|
80 |
+
return Image.fromarray(blurred)
|
81 |
+
|
82 |
+
|
83 |
+
#from PIL import Image
|
84 |
+
#mask = Image.open("image (51).jpg")
|
85 |
+
|
86 |
+
# 使用示例
|
87 |
+
#blurred_mask = blur_mask_opencv(mask, 33, 20)
|
88 |
+
#blurred_mask.save("b_mask.png")
|
89 |
+
#blurred_mask
|
90 |
+
|
91 |
+
|
92 |
+
import os
|
93 |
+
import numpy as np
|
94 |
+
from PIL import Image
|
95 |
+
import torch
|
96 |
+
from diffusers import FluxFillPipeline
|
97 |
+
|
98 |
+
# 初始化模型
|
99 |
+
pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
|
100 |
+
pipe.enable_model_cpu_offload()
|
101 |
+
|
102 |
+
def calculate_optimal_dimensions(image):
|
103 |
+
original_width, original_height = image.size
|
104 |
+
MIN_ASPECT_RATIO = 9 / 16
|
105 |
+
MAX_ASPECT_RATIO = 16 / 9
|
106 |
+
FIXED_DIMENSION = 1024
|
107 |
+
original_aspect_ratio = original_width / original_height
|
108 |
+
|
109 |
+
if original_aspect_ratio > 1:
|
110 |
+
width = FIXED_DIMENSION
|
111 |
+
height = round(FIXED_DIMENSION / original_aspect_ratio)
|
112 |
+
else:
|
113 |
+
height = FIXED_DIMENSION
|
114 |
+
width = round(FIXED_DIMENSION * original_aspect_ratio)
|
115 |
+
|
116 |
+
width = (width // 8) * 8
|
117 |
+
height = (height // 8) * 8
|
118 |
+
calculated_aspect_ratio = width / height
|
119 |
+
|
120 |
+
if calculated_aspect_ratio > MAX_ASPECT_RATIO:
|
121 |
+
width = (height * MAX_ASPECT_RATIO // 8) * 8
|
122 |
+
elif calculated_aspect_ratio < MIN_ASPECT_RATIO:
|
123 |
+
height = (width / MIN_ASPECT_RATIO // 8) * 8
|
124 |
+
|
125 |
+
width = max(width, 576) if width == FIXED_DIMENSION else width
|
126 |
+
height = max(height, 576) if height == FIXED_DIMENSION else height
|
127 |
+
return width, height
|
128 |
+
|
129 |
+
def calculate_white_pixel_ratio(mask_image):
|
130 |
+
mask_array = np.array(mask_image.convert('L'))
|
131 |
+
white_pixels = np.sum(mask_array == 255)
|
132 |
+
total_pixels = mask_array.size
|
133 |
+
return white_pixels / total_pixels
|
134 |
+
|
135 |
+
def inpaint(image_path, mask_path, output_path):
|
136 |
+
image = Image.open(image_path).convert("RGB")
|
137 |
+
mask = Image.open(mask_path).convert("L")
|
138 |
+
#mask = blur_mask_opencv(mask, 33, 20)
|
139 |
+
mask = blur_mask_opencv(mask, 33, 60)
|
140 |
+
|
141 |
+
white_ratio = calculate_white_pixel_ratio(mask)
|
142 |
+
if white_ratio >= 1 / 3:
|
143 |
+
image.save(output_path)
|
144 |
+
return
|
145 |
+
|
146 |
+
width, height = calculate_optimal_dimensions(image)
|
147 |
+
result = pipe(
|
148 |
+
prompt="",
|
149 |
+
height=height,
|
150 |
+
width=width,
|
151 |
+
image=image,
|
152 |
+
mask_image=mask,
|
153 |
+
num_inference_steps=40,
|
154 |
+
guidance_scale=28,
|
155 |
+
).images[0]
|
156 |
+
result.save(output_path)
|
157 |
+
|
158 |
+
def process_all_image_pairs():
|
159 |
+
image_dir = "original_images"
|
160 |
+
mask_dir = "mask_images"
|
161 |
+
output_dir = "inpaint_results"
|
162 |
+
os.makedirs(output_dir, exist_ok=True)
|
163 |
+
|
164 |
+
for filename in os.listdir(image_dir):
|
165 |
+
base_name = os.path.splitext(filename)[0]
|
166 |
+
image_path = os.path.join(image_dir, filename)
|
167 |
+
mask_path = os.path.join(mask_dir, filename)
|
168 |
+
output_path = os.path.join(output_dir, filename)
|
169 |
+
|
170 |
+
if not os.path.exists(mask_path):
|
171 |
+
continue
|
172 |
+
|
173 |
+
try:
|
174 |
+
inpaint(image_path, mask_path, output_path)
|
175 |
+
except Exception as e:
|
176 |
+
print(f"❌ 处理 {filename} 时出错: {e}")
|
177 |
+
|
178 |
+
print(f"✅ 修复结果已保存至 {output_dir}")
|
179 |
+
|
180 |
+
if __name__ == "__main__":
|
181 |
+
process_all_image_pairs()
|
182 |
+
|
183 |
+
import os
|
184 |
+
from datasets import load_dataset, DatasetDict, Image as HfImage
|
185 |
+
from PIL import Image
|
186 |
+
|
187 |
+
def update_dataset_with_inpaint_results():
|
188 |
+
dataset_name = "svjack/InfiniteYou_PosterCraft_Wang_Leehom_Poster_FP8_WAV_text_mask"
|
189 |
+
input_dataset = load_dataset(dataset_name)["train"]
|
190 |
+
|
191 |
+
output_dir = "inpaint_results"
|
192 |
+
output_dataset_path = "InfiniteYou_PosterCraft_Wang_Leehom_Poster_FP8_WAV_text_mask_inpaint"
|
193 |
+
os.makedirs(output_dataset_path, exist_ok=True)
|
194 |
+
|
195 |
+
def add_inpaint_image(example, idx):
|
196 |
+
output_path = os.path.join(output_dir, f"{idx:04d}.png")
|
197 |
+
|
198 |
+
if os.path.exists(output_path):
|
199 |
+
example["inpaint_image"] = output_path
|
200 |
+
else:
|
201 |
+
# 如果没有生成修复图像,则使用原始图像
|
202 |
+
example["inpaint_image"] = example["Wang_Leehom_poster_image"]
|
203 |
+
|
204 |
+
return example
|
205 |
+
|
206 |
+
updated_dataset = input_dataset.map(
|
207 |
+
lambda ex, idx: add_inpaint_image(ex, idx),
|
208 |
+
with_indices=True,
|
209 |
+
batched=False,
|
210 |
+
num_proc=1
|
211 |
+
)
|
212 |
+
|
213 |
+
updated_dataset = updated_dataset.cast_column("inpaint_image", HfImage())
|
214 |
+
updated_dataset.save_to_disk(output_dataset_path)
|
215 |
+
|
216 |
+
print(f"✅ 更新后的数据集已保存至 {output_dataset_path}")
|
217 |
+
|
218 |
+
if __name__ == "__main__":
|
219 |
+
update_dataset_with_inpaint_results()
|
220 |
+
'''
|
221 |
+
|
222 |
import torch
|
223 |
import spaces
|
224 |
import gradio as gr
|