Spaces:
svjack
/
Runtime error

svjack commited on
Commit
d73ec66
·
verified ·
1 Parent(s): 75076bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py CHANGED
@@ -1,3 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from transformers import pipeline, SamModel, SamProcessor
2
  import torch
3
  import numpy as np
 
1
+ '''
2
+ https://huggingface.co/spaces/merve/OWLSAM
3
+
4
+ text,letter,watermark
5
+
6
+ vim run_text_mask.py
7
+
8
+ from gradio_client import Client, handle_file
9
+ from datasets import load_dataset, Image as HfImage
10
+ from PIL import ImageOps, Image
11
+ import numpy as np
12
+ import os
13
+ from tqdm import tqdm
14
+
15
+ # 初始化客户端
16
+ client = Client("http://localhost:7860")
17
+
18
+ # 加载数据集
19
+ dataset_name = "svjack/InfiniteYou_PosterCraft_Wang_Leehom_Poster_FP8_WAV"
20
+ dataset = load_dataset(dataset_name)
21
+
22
+ # 创建保存 mask 的文件夹
23
+ os.makedirs("mask_images", exist_ok=True)
24
+
25
+ #### 832, 1216
26
+ #### (864, 1152)
27
+ def combine_non_white_regions(annotations):
28
+ canvas = None
29
+ for i, annotation in enumerate(annotations):
30
+ img = Image.open(annotation["image"]).convert("RGBA")
31
+ img_array = np.array(img)
32
+ if canvas is None:
33
+ height, width = img_array.shape[:2]
34
+ canvas = np.zeros((height, width, 4), dtype=np.uint8)
35
+ rgb = img_array[..., :3]
36
+ non_white_mask = np.any(rgb < 240, axis=-1, keepdims=True)
37
+ alpha_layer = np.where(non_white_mask, img_array[..., 3:], 0)
38
+ processed_img = np.concatenate([rgb, alpha_layer], axis=-1)
39
+ canvas = np.where(processed_img[..., 3:] > 0, processed_img, canvas)
40
+ if canvas is None:
41
+ height = 1152
42
+ width = 864
43
+ result_array = np.zeros((height, width, 4), dtype=np.uint8)
44
+ result_array[..., :3] = 255
45
+ result_array[..., 3] = 255
46
+ return Image.fromarray(result_array.astype(np.uint8))
47
+
48
+ result_array = np.zeros((height, width, 4), dtype=np.uint8)
49
+ result_array[..., :3] = 255
50
+ result_array[..., 3] = 255
51
+ result_array = np.where(canvas[..., 3:] > 0, canvas, result_array)
52
+
53
+ non_white_mask = np.any(result_array[..., :3] < 255, axis=-1)
54
+ result_array[non_white_mask] = [0, 0, 0, 255]
55
+
56
+ return Image.fromarray(result_array.astype(np.uint8))
57
+
58
+ def generate_mask(image, idx):
59
+ try:
60
+ # 保存原始图片为临时文件
61
+ temp_input_path = f"mask_images/temp_{idx:04d}.jpg"
62
+ image.save(temp_input_path)
63
+
64
+ # 调用 Gradio API
65
+ result = client.predict(
66
+ image=handle_file(temp_input_path),
67
+ texts="text,letter,watermark",
68
+ threshold=0.05,
69
+ sam_threshold=0.88,
70
+ api_name="/predict"
71
+ )
72
+
73
+ # 生成 mask 图像
74
+ mask_image = combine_non_white_regions(result["annotations"])
75
+ mask_image = ImageOps.invert(mask_image.convert("RGB"))
76
+
77
+ # 保存 mask 图像
78
+ output_mask_path = f"mask_images/mask_{idx:04d}.jpg"
79
+ mask_image.save(output_mask_path)
80
+
81
+ return {"mask_image": output_mask_path}
82
+
83
+ except Exception as e:
84
+ print(f"生成 mask 时出错 (index={idx}): {e}")
85
+ return {"mask_image": None}
86
+
87
+ # 使用 map 处理整个数据集
88
+ updated_dataset = dataset["train"].map(
89
+ lambda example, idx: generate_mask(example["Wang_Leehom_poster_image"], idx),
90
+ with_indices=True,
91
+ num_proc=1,
92
+ batched=False
93
+ )
94
+
95
+ # 转换列类型为 Image
96
+ updated_dataset = updated_dataset.cast_column("mask_image", HfImage())
97
+
98
+ # 保存更新后的数据集
99
+ output_path = "Wang_Leehom_PosterCraft_with_Mask"
100
+ updated_dataset.save_to_disk(output_path)
101
+
102
+ print(f"✅ 已生成包含 mask 的数据集并保存至: {output_path}")
103
+ '''
104
+
105
+
106
  from transformers import pipeline, SamModel, SamProcessor
107
  import torch
108
  import numpy as np