import gradio as gr
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel, AutoProcessor, AutoModelForImageClassification
import torch
import matplotlib.pyplot as plt
import numpy as np
from openai import OpenAI
from huggingface_hub import hf_hub_download
from yolo_world.models.detectors import build_detector
from mmcv import Config
# Initialize models
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
sam_checkpoint = hf_hub_download(
    repo_id="facebook/sam-vit-large",  # repository ID
    filename="model.safetensors",      # model weights file
    use_auth_token=False               # public repo, no authentication needed
)
# NOTE: sam_model_registry expects an original-format SAM checkpoint (.pth); the
# transformers-format safetensors file downloaded above may need to be converted
# or swapped for the matching original checkpoint before loading. The registry
# key is set to "vit_l" to match the vit-large repository.
sam = sam_model_registry["vit_l"](checkpoint=sam_checkpoint)
sam_predictor = SamPredictor(sam)
# Download YOLO-World weights from Hugging Face
yolo_checkpoint = hf_hub_download(
    repo_id="stevengrove/YOLO-World",  # Hugging Face repository ID
    filename="yolo_world_v2_xl_obj365v1_goldg_cc3mlite_pretrain.pth",  # weights file
    use_auth_token=False               # public repo, no authentication needed
)
# Load the YOLO-World config file
yolo_config = Config.fromfile('path/to/yolo_world_config.py')  # replace with the actual config path
# Build the YOLO-World model (build_detector is assumed to be exposed by the
# YOLO-World codebase; adapt the import/builder if the repo differs)
yolo_model = build_detector(yolo_config.model)
# Load the weights into the model
checkpoint = torch.load(yolo_checkpoint, map_location="cpu")  # load on CPU; move to GPU later if needed
yolo_model.load_state_dict(checkpoint["state_dict"])
yolo_model.eval()  # set to evaluation mode
# WD tagger for anime-style images (assumes the repo loads through transformers;
# otherwise use the ONNX/timm loading recommended by the repo)
wd_processor = AutoProcessor.from_pretrained("SmilingWolf/wd-vit-tagger-v3")
wd_model = AutoModelForImageClassification.from_pretrained("SmilingWolf/wd-vit-tagger-v3")
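# Optional sketch (not part of the original flow): move the heavier models to
# GPU when one is available. The rest of the script keeps tensors on CPU, so
# model inputs would also need matching .to(device) calls.
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   sam.to(device); yolo_model.to(device); clip_model.to(device); blip_model.to(device)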
# Automatically classify the image type (anime vs. photographic)
def classify_image_type(image):
    inputs = wd_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = wd_model(**inputs)
    scores = torch.softmax(outputs.logits, dim=1)[0]
    # the label-to-id mapping lives on the model config, not the processor;
    # "anime" is assumed to be one of the tagger's labels
    anime_score = scores[wd_model.config.label2id["anime"]].item()
    return "anime" if anime_score > 0.5 else "real"
# Segment the detected objects with SAM
def segment_objects(image, boxes):
    image_np = np.array(image)
    sam_predictor.set_image(image_np)
    masks = []
    for box in boxes:
        mask, _, _ = sam_predictor.predict(
            point_coords=None, point_labels=None, box=np.array(box), multimask_output=False
        )
        masks.append(mask)
    return masks
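# segment_objects() is not wired into the Gradio flow below; a possible hookup,
# sketched under the assumption that detect_objects() returned boxes for a photo:
#   objects = detect_objects(img, "real")
#   masks = segment_objects(img, [o["box"] for o in objects if "box" in o])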
# Detect objects in the image
def detect_objects(image, image_type):
    if image_type == "real":
        # NOTE: assumes the detector exposes a predict() call returning dicts with
        # "class"/"bbox"/"confidence" keys; adapt to the actual YOLO-World inference API
        results = yolo_model.predict(np.array(image), conf=0.25)
        objects = [{"label": r["class"], "box": r["bbox"], "confidence": r["confidence"]} for r in results]
    else:
        inputs = wd_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = wd_model(**inputs)
        scores = torch.softmax(outputs.logits, dim=1)[0]
        top_k = torch.topk(scores, k=5)
        # map class indices back to tag names via the model config
        objects = [{"label": wd_model.config.id2label[top_k.indices[i].item()], "confidence": top_k.values[i].item()} for i in range(5)]
    return objects
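# Each returned entry is {"label", "box", "confidence"} for photographic images
# and {"label", "confidence"} (no box) for anime images, which is why the
# downstream code treats "box" as optional.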
# Generate semantic descriptions with BLIP
def generate_object_descriptions(image, objects):
    descriptions = []
    for obj in objects:
        box = obj.get("box", None)
        if box:
            cropped = image.crop(box)
        else:
            cropped = image
        inputs = blip_processor(cropped, return_tensors="pt")
        caption = blip_model.generate(**inputs, max_length=128, num_beams=5, no_repeat_ngram_size=2)
        description = blip_processor.decode(caption[0], skip_special_tokens=True)
        descriptions.append({"label": obj["label"], "description": description})
    return descriptions
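# Cropping to the detected box keeps each BLIP caption focused on a single
# object; when no box is available (anime tags), the whole image is captioned.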
# Visualize feature differences
def plot_feature_differences(latent_diff, descriptions, prefix):
    diff_magnitude = [abs(x) for x in latent_diff[0]]
    indices = range(len(diff_magnitude))
    top_indices = np.argsort(diff_magnitude)[-10:][::-1]
    plt.figure(figsize=(8, 4))
    plt.bar(indices, diff_magnitude, alpha=0.7)
    plt.xlabel("Feature Index")
    plt.ylabel("Magnitude of Difference")
    plt.title("Feature Differences (Bar Chart)")
    bar_chart_path = f"{prefix}_bar_chart.png"
    plt.savefig(bar_chart_path)
    plt.close()
    plt.figure(figsize=(6, 6))
    plt.pie(
        [diff_magnitude[i] for i in top_indices],
        # label a slice with an object label when one exists at that position,
        # otherwise fall back to the raw feature index to avoid an IndexError
        labels=[descriptions[i] if i < len(descriptions) else f"dim {i}" for i in top_indices],
        autopct="%1.1f%%",
        startangle=140
    )
    plt.title("Top 10 Feature Differences (Pie Chart)")
    pie_chart_path = f"{prefix}_pie_chart.png"
    plt.savefig(pie_chart_path)
    plt.close()
    return bar_chart_path, pie_chart_path
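# The two PNG paths are written to the current working directory and returned
# so that the Gradio gallery below can display them.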
# Generate a detailed comparative analysis via an LLM API
def generate_text_analysis(api_key, api_type, caption_a, caption_b):
    if api_type == "DeepSeek":
        client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
    else:
        client = OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model="gpt-4" if api_type == "GPT" else "deepseek-chat",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"Description of image A: {caption_a}\nDescription of image B: {caption_b}\nPlease provide a detailed comparative analysis of the two images."}
        ]
    )
    return response.choices[0].message.content.strip()
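# This assumes the DeepSeek endpoint is OpenAI-compatible: the same
# chat.completions.create call works, with only base_url and model name changed.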
# Analyze a single pair of images
def analyze_images(img_a, img_b, api_key, api_type, prefix):
    type_a = classify_image_type(img_a)
    type_b = classify_image_type(img_b)
    objects_a = detect_objects(img_a, type_a)
    objects_b = detect_objects(img_b, type_b)
    descriptions_a = generate_object_descriptions(img_a, objects_a)
    descriptions_b = generate_object_descriptions(img_b, objects_b)
    inputs = clip_processor(images=img_a, return_tensors="pt")
    features_a = clip_model.get_image_features(**inputs).detach().numpy()
    inputs = clip_processor(images=img_b, return_tensors="pt")
    features_b = clip_model.get_image_features(**inputs).detach().numpy()
    latent_diff = np.abs(features_a - features_b).tolist()
    bar_chart, pie_chart = plot_feature_differences(latent_diff, [d["label"] for d in descriptions_a], prefix)
    # flatten the per-object descriptions into readable captions for the LLM prompt
    caption_a = "; ".join(f"{d['label']}: {d['description']}" for d in descriptions_a)
    caption_b = "; ".join(f"{d['label']}: {d['description']}" for d in descriptions_b)
    text_analysis = generate_text_analysis(api_key, api_type, caption_a, caption_b)
    return {
        "bar_chart": bar_chart,
        "pie_chart": pie_chart,
        "text_analysis": text_analysis
    }
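# The returned dict is consumed by process_batch() below: the chart paths feed
# the gallery and text_analysis feeds the results textbox.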
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Comprehensive Image Comparison and Analysis Tool")
    api_key_input = gr.Textbox(label="API Key", placeholder="Enter your API key", type="password")
    api_type_input = gr.Radio(label="API type", choices=["GPT", "DeepSeek"], value="GPT")
    images_a_input = gr.File(label="Upload images for folder A", file_types=[".png", ".jpg"], file_count="multiple")
    images_b_input = gr.File(label="Upload images for folder B", file_types=[".png", ".jpg"], file_count="multiple")
    analyze_button = gr.Button("Start analysis")
    result_gallery = gr.Gallery(label="Difference visualizations")
    result_text = gr.Textbox(label="Analysis results", lines=5)
    def process_batch(images_a, images_b, api_key, api_type):
        images_a = [Image.open(img).convert("RGB") for img in images_a]
        images_b = [Image.open(img).convert("RGB") for img in images_b]
        results = [analyze_images(img_a, img_b, api_key, api_type, f"comparison_{i+1}") for i, (img_a, img_b) in enumerate(zip(images_a, images_b))]
        # split the per-pair results into the two outputs Gradio expects:
        # chart image paths for the gallery, joined analysis text for the textbox
        charts = [path for r in results for path in (r["bar_chart"], r["pie_chart"])]
        analyses = "\n\n".join(f"Pair {i+1}: {r['text_analysis']}" for i, r in enumerate(results))
        return charts, analyses
    analyze_button.click(process_batch, inputs=[images_a_input, images_b_input, api_key_input, api_type_input], outputs=[result_gallery, result_text])
demo.launch()