import gradio as gr
import os
from PIL import Image, ImageChops, ImageFilter
from ultralytics import YOLO
from segment_anything import SamPredictor, sam_model_registry
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel, AutoProcessor, AutoModelForImageClassification
import torch
import matplotlib.pyplot as plt
import numpy as np
from openai import OpenAI
from huggingface_hub import hf_hub_download
from yolo_world.models.detectors import build_detector
from mmcv import Config

# Initialize models
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
# Download a SAM checkpoint from the Hugging Face Hub
sam_checkpoint = hf_hub_download(
    repo_id="facebook/sam-vit-large",  # repository ID
    filename="model.safetensors",      # model weights file
    use_auth_token=False               # public repo, no authentication needed
)
# The registry key must match the downloaded variant ("vit_l" for sam-vit-large);
# sam_model_registry expects a checkpoint in the original SAM state-dict format.
sam = sam_model_registry["vit_l"](checkpoint=sam_checkpoint)
sam_predictor = SamPredictor(sam)
# Download YOLO-World weights from the Hugging Face Hub
yolo_checkpoint = hf_hub_download(
    repo_id="stevengrove/YOLO-World",  # Hugging Face repository ID
    filename="yolo_world_v2_xl_obj365v1_goldg_cc3mlite_pretrain.pth",  # weights file name
    use_auth_token=False  # public repo, no authentication needed
)
# Load the YOLO-World config file
yolo_config = Config.fromfile('path/to/yolo_world_config.py')  # replace with the actual config file path
# Build the YOLO-World model
yolo_model = build_detector(yolo_config.model)
# Load the weights into the model
checkpoint = torch.load(yolo_checkpoint, map_location="cpu")  # load on CPU; move to GPU later if needed
yolo_model.load_state_dict(checkpoint["state_dict"])
yolo_model.eval()  # set to evaluation mode

wd_processor = AutoProcessor.from_pretrained("SmilingWolf/wd-vit-tagger-v3")
wd_model = AutoModelForImageClassification.from_pretrained("SmilingWolf/wd-vit-tagger-v3")

# Automatically classify the image type (anime vs. real photo)
def classify_image_type(image):
    inputs = wd_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = wd_model(**inputs)
    scores = torch.softmax(outputs.logits, dim=1)[0]
    # the label-to-id mapping lives on the model config, not the processor
    anime_score = scores[wd_model.config.label2id["anime"]].item()
    return "anime" if anime_score > 0.5 else "real"

# Segment objects in the image with SAM
def segment_objects(image, boxes):
    image_np = np.array(image)
    sam_predictor.set_image(image_np)
    masks = []
    for box in boxes:
        mask, _, _ = sam_predictor.predict(
            point_coords=None, point_labels=None, box=box, multimask_output=False
        )
        masks.append(mask)
    return masks

# Detect objects in the image
def detect_objects(image, image_type):
    if image_type == "real":
        # Assumes the detector exposes a predict() helper returning dicts with
        # "class", "bbox", and "confidence" keys for each detection.
        results = yolo_model.predict(np.array(image), conf=0.25)
        objects = [{"label": r["class"], "box": r["bbox"], "confidence": r["confidence"]} for r in results]
    else:
        inputs = wd_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = wd_model(**inputs)
        scores = torch.softmax(outputs.logits, dim=1)[0]
        top_k = torch.topk(scores, k=5)
        # map class indices back to tag names via the model config
        objects = [{"label": wd_model.config.id2label[top_k.indices[i].item()], "confidence": top_k.values[i].item()} for i in range(5)]
    return objects

# Generate semantic descriptions for the detected objects
def generate_object_descriptions(image, objects):
    descriptions = []
    for obj in objects:
        box = obj.get("box", None)
        if box:
            cropped = image.crop(box)
        else:
            cropped = image
        inputs = blip_processor(cropped, return_tensors="pt")
        caption = blip_model.generate(**inputs, max_length=128, num_beams=5, no_repeat_ngram_size=2)
        description = blip_processor.decode(caption[0], skip_special_tokens=True)
        descriptions.append({"label": obj["label"], "description": description})
    return descriptions

# Visualize feature differences between the two images
def plot_feature_differences(latent_diff, descriptions, prefix):
    diff_magnitude = [abs(x) for x in latent_diff[0]]
    indices = range(len(diff_magnitude))
    top_indices = np.argsort(diff_magnitude)[-10:][::-1]

    plt.figure(figsize=(8, 4))
    plt.bar(indices, diff_magnitude, alpha=0.7)
    plt.xlabel("Feature Index")
    plt.ylabel("Magnitude of Difference")
    plt.title("Feature Differences (Bar Chart)")
    bar_chart_path = f"{prefix}_bar_chart.png"
    plt.savefig(bar_chart_path)
    plt.close()

    plt.figure(figsize=(6, 6))
    plt.pie(
        [diff_magnitude[i] for i in top_indices],
        # descriptions only covers the detected objects, so fall back to the
        # raw feature index when there is no matching label
        labels=[descriptions[i] if i < len(descriptions) else f"feature {i}" for i in top_indices],
        autopct="%1.1f%%",
        startangle=140
    )
    plt.title("Top 10 Feature Differences (Pie Chart)")
    pie_chart_path = f"{prefix}_pie_chart.png"
    plt.savefig(pie_chart_path)
    plt.close()

    return bar_chart_path, pie_chart_path

# Generate a detailed textual analysis via the chat API
def generate_text_analysis(api_key, api_type, caption_a, caption_b):
    if api_type == "DeepSeek":
        client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
    else:
        client = OpenAI(api_key=api_key)

    response = client.chat.completions.create(
        model="gpt-4" if api_type == "GPT" else "deepseek-chat",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"图片A的描述为:{caption_a}。\n图片B的描述为:{caption_b}。\n请对两张图片进行详细对比分析。"}
        ]
    )
    return response.choices[0].message.content.strip()

# Analyze a single pair of images
def analyze_images(img_a, img_b, api_key, api_type, prefix):
    type_a = classify_image_type(img_a)
    type_b = classify_image_type(img_b)

    objects_a = detect_objects(img_a, type_a)
    objects_b = detect_objects(img_b, type_b)

    descriptions_a = generate_object_descriptions(img_a, objects_a)
    descriptions_b = generate_object_descriptions(img_b, objects_b)

    inputs = clip_processor(images=img_a, return_tensors="pt")
    features_a = clip_model.get_image_features(**inputs).detach().numpy()

    inputs = clip_processor(images=img_b, return_tensors="pt")
    features_b = clip_model.get_image_features(**inputs).detach().numpy()

    latent_diff = np.abs(features_a - features_b).tolist()

    bar_chart, pie_chart = plot_feature_differences(latent_diff, [d['label'] for d in descriptions_a], prefix)
    text_analysis = generate_text_analysis(api_key, api_type, descriptions_a, descriptions_b)

    return {
        "bar_chart": bar_chart,
        "pie_chart": pie_chart,
        "text_analysis": text_analysis
    }

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Comprehensive Image Comparison and Analysis Tool")
    api_key_input = gr.Textbox(label="API Key", placeholder="Enter your API key", type="password")
    api_type_input = gr.Radio(label="API type", choices=["GPT", "DeepSeek"], value="GPT")
    images_a_input = gr.File(label="Upload images for folder A", file_types=[".png", ".jpg"], file_count="multiple")
    images_b_input = gr.File(label="Upload images for folder B", file_types=[".png", ".jpg"], file_count="multiple")
    analyze_button = gr.Button("Start analysis")
    result_gallery = gr.Gallery(label="Difference visualizations")
    result_text = gr.Textbox(label="Analysis results", lines=5)

    def process_batch(images_a, images_b, api_key, api_type):
        images_a = [Image.open(img).convert("RGB") for img in images_a]
        images_b = [Image.open(img).convert("RGB") for img in images_b]
        results = [analyze_images(img_a, img_b, api_key, api_type, f"comparison_{i+1}") for i, (img_a, img_b) in enumerate(zip(images_a, images_b))]
        # Flatten the per-pair results into the two Gradio outputs:
        # all chart images go to the gallery, all text analyses are concatenated.
        gallery_images = [path for r in results for path in (r["bar_chart"], r["pie_chart"])]
        text_output = "\n\n".join(f"Pair {i+1}:\n{r['text_analysis']}" for i, r in enumerate(results))
        return gallery_images, text_output

    analyze_button.click(process_batch, inputs=[images_a_input, images_b_input, api_key_input, api_type_input], outputs=[result_gallery, result_text])

demo.launch()