'''
Batch-masking client for this Space (upstream: https://huggingface.co/spaces/merve/OWLSAM).

Detection prompt used below: text,letter,watermark

vim run_text_mask.py

from gradio_client import Client, handle_file
from datasets import load_dataset, Image as HfImage
from PIL import ImageOps, Image
import numpy as np
import os
from tqdm import tqdm

# Initialize the Gradio client (the Space running locally)
client = Client("http://localhost:7860")

# Load the source dataset
dataset_name = "svjack/InfiniteYou_PosterCraft_Wang_Leehom_Poster_FP8_WAV"
dataset = load_dataset(dataset_name)

# Create the folder for saving mask images
os.makedirs("mask_images", exist_ok=True)

#### 832, 1216
#### (864, 1152)
def combine_non_white_regions(annotations):
    """Merge the annotation layers returned by the API into one image:
    non-white pixels from every layer are painted black on a white canvas."""
    canvas = None
    for annotation in annotations:
        img = Image.open(annotation["image"]).convert("RGBA")
        img_array = np.array(img)
        if canvas is None:
            height, width = img_array.shape[:2]
            canvas = np.zeros((height, width, 4), dtype=np.uint8)
        rgb = img_array[..., :3]
        # Treat any pixel with a channel value below 240 as "non-white" content
        non_white_mask = np.any(rgb < 240, axis=-1, keepdims=True)
        alpha_layer = np.where(non_white_mask, img_array[..., 3:], 0)
        processed_img = np.concatenate([rgb, alpha_layer], axis=-1)
        canvas = np.where(processed_img[..., 3:] > 0, processed_img, canvas)

    if canvas is None:
        # No annotations were returned: fall back to an all-white canvas
        height, width = 1152, 864
        result_array = np.zeros((height, width, 4), dtype=np.uint8)
        result_array[..., :3] = 255
        result_array[..., 3] = 255
        return Image.fromarray(result_array)

    # Paint the combined non-white regions black on a white background
    result_array = np.zeros((height, width, 4), dtype=np.uint8)
    result_array[..., :3] = 255
    result_array[..., 3] = 255
    result_array = np.where(canvas[..., 3:] > 0, canvas, result_array)

    non_white_mask = np.any(result_array[..., :3] < 255, axis=-1)
    result_array[non_white_mask] = [0, 0, 0, 255]

    return Image.fromarray(result_array)
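# A minimal usage sketch (hypothetical layer files): each annotation layer
# contributes its non-white pixels, and the merged result is a black-on-white
# mask the same size as the first layer.
#
#   annotations = [{"image": "layer_0.png"}, {"image": "layer_1.png"}]
#   mask = combine_non_white_regions(annotations)   # PIL RGBA image
#   mask.save("combined_mask.png")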

def generate_mask(image, idx):
    try:
        # Save the original image to a temporary file
        temp_input_path = f"mask_images/temp_{idx:04d}.jpg"
        image.save(temp_input_path)

        # Call the Gradio API
        result = client.predict(
            image=handle_file(temp_input_path),
            texts="text,letter,watermark",
            threshold=0.05,
            sam_threshold=0.88,
            api_name="/predict"
        )

        # Build the combined mask image and invert it
        mask_image = combine_non_white_regions(result["annotations"])
        mask_image = ImageOps.invert(mask_image.convert("RGB"))

        # Save the mask image
        output_mask_path = f"mask_images/mask_{idx:04d}.jpg"
        mask_image.save(output_mask_path)

        return {"mask_image": output_mask_path}

    except Exception as e:
        print(f"Error while generating mask (index={idx}): {e}")
        return {"mask_image": None}

# Process the whole dataset with map
updated_dataset = dataset["train"].map(
    lambda example, idx: generate_mask(example["Wang_Leehom_poster_image"], idx),
    with_indices=True,
    num_proc=1,
    batched=False
)

# Cast the new column to the Image feature type
updated_dataset = updated_dataset.cast_column("mask_image", HfImage())

# Save the updated dataset to disk
output_path = "Wang_Leehom_PosterCraft_with_Mask"
updated_dataset.save_to_disk(output_path)

print(f"✅ Generated the dataset with masks and saved it to: {output_path}")
'''


from transformers import pipeline, SamModel, SamProcessor
import torch
import numpy as np
import gradio as gr
import spaces

checkpoint = "google/owlv2-base-patch16-ensemble"
detector = pipeline(model=checkpoint, task="zero-shot-object-detection", device="cuda")
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to("cuda")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
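
# Note: the calls above assume a CUDA device, as on the Space's GPU hardware.
# A minimal sketch for running locally without a GPU (an assumption, not part
# of the original Space) would select the device dynamically:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   detector = pipeline(model=checkpoint, task="zero-shot-object-detection", device=device)
#   sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)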


@spaces.GPU
def query(image, texts, threshold, sam_threshold):
    texts = texts.split(",")

    # Zero-shot detection: OWLv2 turns the text labels into candidate boxes
    predictions = detector(
        image,
        candidate_labels=texts,
        threshold=threshold
    )

    result_labels = []
    for pred in predictions:
        label = pred["label"]
        box = [round(pred["box"]["xmin"], 2), round(pred["box"]["ymin"], 2),
               round(pred["box"]["xmax"], 2), round(pred["box"]["ymax"], 2)]

        # Prompt SAM with the detected box to get a segmentation mask
        inputs = sam_processor(
            image,
            input_boxes=[[box]],
            return_tensors="pt"
        ).to("cuda")
        with torch.no_grad():
            outputs = sam_model(**inputs)

        mask = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )
        iou_scores = outputs["iou_scores"]

        # Filter masks by predicted IoU. Note: the filtered result is not used
        # below; the first predicted mask is displayed, as in the upstream Space.
        masks, iou_scores, boxes = sam_processor.image_processor.filter_masks(
            mask[0],
            iou_scores[0].cpu(),
            inputs["original_sizes"][0].cpu(),
            box,
            pred_iou_thresh=sam_threshold,
        )

        result_labels.append((mask[0][0][0].numpy(), label))
    return image, result_labels
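
# The return value matches the format gr.AnnotatedImage expects:
# (base_image, [(mask_array, label), ...]), where each mask_array is an
# HxW array highlighting one detected region on the base image.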


description = "This Space combines OWLv2, a state-of-the-art zero-shot object detection model, with SAM, a state-of-the-art mask generation model. SAM normally doesn't accept text input; combining it with OWLv2 makes SAM text-promptable. Try the example, or input an image and comma-separated candidate labels to segment."
demo = gr.Interface(
    query,
    inputs=[
        gr.Image(type="pil", label="Image Input"),
        gr.Textbox(label="Candidate Labels"),
        gr.Slider(0, 1, value=0.05, label="Confidence Threshold for OWL"),
        gr.Slider(0, 1, value=0.88, label="IoU threshold for SAM"),
    ],
    outputs="annotatedimage",
    title="OWL 🤝 SAM",
    description=description,
    examples=[
        ["./cats.png", "cat", 0.1, 0.88],
    ],
    cache_examples=True
)
demo.launch(debug=True, share=True)