PSNbst committed on
Commit 9c08ff8 · verified · 1 Parent(s): 85ae7eb

Update app.py

Files changed (1)
  1. app.py +87 -103
app.py CHANGED
@@ -1,7 +1,9 @@
  import gradio as gr
  import os
  from PIL import Image, ImageChops, ImageFilter
- from transformers import CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditionalGeneration
  import torch
  import matplotlib.pyplot as plt
  import numpy as np
@@ -12,57 +14,70 @@ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
-
- # Names for the CLIP feature dimensions (placeholder names; adjust as needed)
- CLIP_FEATURE_NAMES = [f"Dimension {i}" for i in range(512)]
-
- # Image-processing helpers
- def compute_difference_images(img_a, img_b):
-     def extract_sketch(image):
-         grayscale = image.convert("L")
-         inverted = ImageChops.invert(grayscale)
-         sketch = ImageChops.screen(grayscale, inverted)
-         return sketch
-
-     def compute_normal_map(image):
-         edges = image.filter(ImageFilter.FIND_EDGES)
-         return edges
-
-     diff_overlay = ImageChops.difference(img_a, img_b)
-     return {
-         "original_a": img_a,
-         "original_b": img_b,
-         "sketch_a": extract_sketch(img_a),
-         "sketch_b": extract_sketch(img_b),
-         "normal_a": compute_normal_map(img_a),
-         "normal_b": compute_normal_map(img_b),
-         "diff_overlay": diff_overlay
-     }
-
- # Save the generated images to files
- def save_images(images, prefix):
-     paths = []
-     for key, img in images.items():
-         path = f"{prefix}_{key}.png"
-         img.save(path)
-         paths.append((path, key.replace("_", " ").capitalize()))
-     return paths
-
- # Generate a more detailed caption with BLIP
- def generate_detailed_caption(image):
-     inputs = blip_processor(image, return_tensors="pt")
-     caption = blip_model.generate(**inputs, max_length=128, num_beams=5, no_repeat_ngram_size=2)
-     return blip_processor.decode(caption[0], skip_special_tokens=True)

  # Feature-difference visualization
- def plot_feature_differences(latent_diff, prefix):
      diff_magnitude = [abs(x) for x in latent_diff[0]]
      indices = range(len(diff_magnitude))
-     top_indices = np.argsort(diff_magnitude)[-10:][::-1]  # top 10 differences

      plt.figure(figsize=(8, 4))
      plt.bar(indices, diff_magnitude, alpha=0.7)
-     plt.xlabel("Feature Index (Latent Dimension)")
      plt.ylabel("Magnitude of Difference")
      plt.title("Feature Differences (Bar Chart)")
      bar_chart_path = f"{prefix}_bar_chart.png"
@@ -72,7 +87,7 @@ def plot_feature_differences(latent_diff, prefix):
      plt.figure(figsize=(6, 6))
      plt.pie(
          [diff_magnitude[i] for i in top_indices],
-         labels=[CLIP_FEATURE_NAMES[i] for i in top_indices],
          autopct="%1.1f%%",
          startangle=140
      )
@@ -83,7 +98,7 @@ def plot_feature_differences(latent_diff, prefix):

      return bar_chart_path, pie_chart_path

- # Generate a detailed analysis
  def generate_text_analysis(api_key, api_type, caption_a, caption_b):
      if api_type == "DeepSeek":
          client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
@@ -94,18 +109,21 @@ def generate_text_analysis(api_key, api_type, caption_a, caption_b):
          model="gpt-4" if api_type == "GPT" else "deepseek-chat",
          messages=[
              {"role": "system", "content": "You are a helpful assistant."},
-             {"role": "user", "content": f"图片A的描述为:{caption_a}。图片B的描述为:{caption_b}。\n请对两张图片的内容和潜在特征区别进行详细分析,并输出一个简洁但富有条理的总结。"}
          ]
      )
      return response.choices[0].message.content.strip()

- # Analysis function
  def analyze_images(img_a, img_b, api_key, api_type, prefix):
-     images_diff = compute_difference_images(img_a, img_b)
-     saved_images = save_images(images_diff, prefix)

-     caption_a = generate_detailed_caption(img_a)
-     caption_b = generate_detailed_caption(img_b)

      inputs = clip_processor(images=img_a, return_tensors="pt")
      features_a = clip_model.get_image_features(**inputs).detach().numpy()
@@ -115,66 +133,32 @@ def analyze_images(img_a, img_b, api_key, api_type, prefix):

      latent_diff = np.abs(features_a - features_b).tolist()

-     bar_chart, pie_chart = plot_feature_differences(latent_diff, prefix)
-     text_analysis = generate_text_analysis(api_key, api_type, caption_a, caption_b)

      return {
-         "saved_images": saved_images,
-         "caption_a": caption_a,
-         "caption_b": caption_b,
-         "text_analysis": text_analysis,
          "bar_chart": bar_chart,
-         "pie_chart": pie_chart
      }

- # Batch analysis
- def batch_analyze(images_a, images_b, api_key, api_type):
-     num_pairs = min(len(images_a), len(images_b))
-
-     results = []
-     for i in range(num_pairs):
-         prefix = f"comparison_{i+1}"
-         result = analyze_images(images_a[i], images_b[i], api_key, api_type, prefix)
-         results.append({
-             "pair": (f"Image A-{i+1}", f"Image B-{i+1}"),
-             **result
-         })
-     return results
-
- # Gradio UI
  with gr.Blocks() as demo:
-     gr.Markdown("# 批量图像对比分析工具")
-
-     api_key_input = gr.Textbox(label="API Key", placeholder="输入您的 API Key", type="password")
      api_type_input = gr.Radio(label="API 类型", choices=["GPT", "DeepSeek"], value="GPT")
-     images_a_input = gr.File(label="上传文件夹A图片", file_types=[".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".gif", ".webp"], file_count="multiple")
-     images_b_input = gr.File(label="上传文件夹B图片", file_types=[".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".gif", ".webp"], file_count="multiple")
-     analyze_button = gr.Button("开始批量分析")
-
-     with gr.Row():
-         result_gallery = gr.Gallery(label="差异图像")
-         result_text_analysis = gr.Textbox(label="详细分析", interactive=False, lines=5)

-     def process_batch_analysis(images_a, images_b, api_key, api_type):
          images_a = [Image.open(img).convert("RGB") for img in images_a]
          images_b = [Image.open(img).convert("RGB") for img in images_b]
-         results = batch_analyze(images_a, images_b, api_key, api_type)
-
-         all_images = []
-         all_texts = []

-         for result in results:
-             all_images.extend(result["saved_images"])
-             all_images.append((result["bar_chart"], "Bar Chart"))
-             all_images.append((result["pie_chart"], "Pie Chart"))
-             all_texts.append(f"{result['pair'][0]} vs {result['pair'][1]}:\n{result['text_analysis']}")
-
-         return all_images, "\n\n".join(all_texts)
-
-     analyze_button.click(
-         fn=process_batch_analysis,
-         inputs=[images_a_input, images_b_input, api_key_input, api_type_input],
-         outputs=[result_gallery, result_text_analysis]
-     )

  demo.launch()
 
  import gradio as gr
  import os
  from PIL import Image, ImageChops, ImageFilter
+ from ultralytics import YOLO
+ from segment_anything import SamPredictor, sam_model_registry
+ from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel, AutoProcessor, AutoModelForImageClassification
  import torch
  import matplotlib.pyplot as plt
  import numpy as np
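+ # Pipeline in this version: the WD tagger classifies each image as anime or
+ # real, YOLO or the tagger detects objects, SAM can segment them, BLIP
+ # captions each object, and CLIP features drive the difference charts.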
 
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+ sam_checkpoint = "sam_vit_h_4b8939.pth"  # replace with the actual checkpoint path
+ sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
+ sam_predictor = SamPredictor(sam)
+ yolo_model = YOLO("yolov8x.pt")  # replace with the actual YOLO model path
+ wd_processor = AutoProcessor.from_pretrained("SmilingWolf/wd-v1-4-vit-large-tagger")
+ wd_model = AutoModelForImageClassification.from_pretrained("SmilingWolf/wd-v1-4-vit-large-tagger")
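+ # Note: all models above load once at startup; the SAM checkpoint must exist
+ # locally, while ultralytics can fetch known YOLO weights automatically.
+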
+ # Automatically classify the image type
+ def classify_image_type(image):
+     inputs = wd_processor(images=image, return_tensors="pt")
+     outputs = wd_model(**inputs)
+     scores = torch.softmax(outputs.logits, dim=1)[0]
+     # label2id lives on the model config, not on the processor
+     anime_score = scores[wd_model.config.label2id["anime"]].item()
+     return "anime" if anime_score > 0.5 else "real"
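+ # (Assumes the tagger's label set includes an "anime" class.)
+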
+ # Segment objects in the image with SAM, prompted by bounding boxes
+ def segment_objects(image, boxes):
+     image_np = np.array(image)
+     sam_predictor.set_image(image_np)
+     masks = []
+     for box in boxes:
+         mask, _, _ = sam_predictor.predict(
+             # SAM expects the box prompt as an array, not a plain list
+             point_coords=None, point_labels=None, box=np.asarray(box), multimask_output=False
+         )
+         masks.append(mask)
+     return masks
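+ # (segment_objects is defined here but not yet wired into the analysis flow shown in this diff.)
+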
+ # Detect objects: YOLO for photos, WD-tagger top-k labels for anime images
+ def detect_objects(image, image_type):
+     if image_type == "real":
+         # Ultralytics returns Results objects; read labels/boxes/confidences
+         # from results[0].boxes rather than treating results as dicts
+         results = yolo_model.predict(np.array(image), conf=0.25)
+         objects = [
+             {
+                 "label": yolo_model.names[int(box.cls)],
+                 "box": [int(v) for v in box.xyxy[0].tolist()],
+                 "confidence": float(box.conf),
+             }
+             for box in results[0].boxes
+         ]
+     else:
+         inputs = wd_processor(images=image, return_tensors="pt")
+         outputs = wd_model(**inputs)
+         scores = torch.softmax(outputs.logits, dim=1)[0]
+         top_k = torch.topk(scores, k=5)
+         # Class indices map to label names via the model config
+         objects = [{"label": wd_model.config.id2label[top_k.indices[i].item()], "confidence": top_k.values[i].item()} for i in range(5)]
+     return objects
+
+ # Generate a semantic description for each detected object
+ def generate_object_descriptions(image, objects):
+     descriptions = []
+     for obj in objects:
+         box = obj.get("box", None)
+         if box:
+             cropped = image.crop(box)
+         else:
+             cropped = image
+         inputs = blip_processor(cropped, return_tensors="pt")
+         caption = blip_model.generate(**inputs, max_length=128, num_beams=5, no_repeat_ngram_size=2)
+         description = blip_processor.decode(caption[0], skip_special_tokens=True)
+         descriptions.append({"label": obj["label"], "description": description})
+     return descriptions
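+ # (Only YOLO detections carry a "box"; anime images are captioned whole.)
+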
  # Feature-difference visualization
+ def plot_feature_differences(latent_diff, descriptions, prefix):
      diff_magnitude = [abs(x) for x in latent_diff[0]]
      indices = range(len(diff_magnitude))
+     top_indices = np.argsort(diff_magnitude)[-10:][::-1]

      plt.figure(figsize=(8, 4))
      plt.bar(indices, diff_magnitude, alpha=0.7)
+     plt.xlabel("Feature Index")
      plt.ylabel("Magnitude of Difference")
      plt.title("Feature Differences (Bar Chart)")
      bar_chart_path = f"{prefix}_bar_chart.png"
 
      plt.figure(figsize=(6, 6))
      plt.pie(
          [diff_magnitude[i] for i in top_indices],
+         # Guard: top_indices range over all latent dims, but only a few labels exist
+         labels=[descriptions[i] if i < len(descriptions) else f"Dim {i}" for i in top_indices],
          autopct="%1.1f%%",
          startangle=140
      )
 

      return bar_chart_path, pie_chart_path

+ # Generate the detailed analysis text
  def generate_text_analysis(api_key, api_type, caption_a, caption_b):
      if api_type == "DeepSeek":
          client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
 
          model="gpt-4" if api_type == "GPT" else "deepseek-chat",
          messages=[
              {"role": "system", "content": "You are a helpful assistant."},
+             # Prompt (Chinese): "Image A is described as: {caption_a}. Image B is
+             # described as: {caption_b}. Please give a detailed comparative analysis."
+             {"role": "user", "content": f"图片A的描述为:{caption_a}。\n图片B的描述为:{caption_b}。\n请对两张图片进行详细对比分析。"}
          ]
      )
      return response.choices[0].message.content.strip()

+ # Analyze a single image pair
  def analyze_images(img_a, img_b, api_key, api_type, prefix):
+     type_a = classify_image_type(img_a)
+     type_b = classify_image_type(img_b)
+
+     objects_a = detect_objects(img_a, type_a)
+     objects_b = detect_objects(img_b, type_b)

+     descriptions_a = generate_object_descriptions(img_a, objects_a)
+     descriptions_b = generate_object_descriptions(img_b, objects_b)

      inputs = clip_processor(images=img_a, return_tensors="pt")
      features_a = clip_model.get_image_features(**inputs).detach().numpy()

      latent_diff = np.abs(features_a - features_b).tolist()

+     bar_chart, pie_chart = plot_feature_differences(latent_diff, [d['label'] for d in descriptions_a], prefix)
+     # Flatten the structured descriptions into plain captions for the LLM prompt
+     caption_a = "; ".join(f"{d['label']}: {d['description']}" for d in descriptions_a)
+     caption_b = "; ".join(f"{d['label']}: {d['description']}" for d in descriptions_b)
+     text_analysis = generate_text_analysis(api_key, api_type, caption_a, caption_b)

      return {
          "bar_chart": bar_chart,
+         "pie_chart": pie_chart,
+         "text_analysis": text_analysis
      }

+ # Gradio UI (labels are in Chinese; the title reads "Comprehensive Image Comparison Tool")
  with gr.Blocks() as demo:
+     gr.Markdown("# 综合图像对比分析工具")
+     api_key_input = gr.Textbox(label="API Key", placeholder="输入 API Key", type="password")
      api_type_input = gr.Radio(label="API 类型", choices=["GPT", "DeepSeek"], value="GPT")
+     images_a_input = gr.File(label="上传文件夹A图片", file_types=[".png", ".jpg"], file_count="multiple")
+     images_b_input = gr.File(label="上传文件夹B图片", file_types=[".png", ".jpg"], file_count="multiple")
+     analyze_button = gr.Button("开始分析")
+     result_gallery = gr.Gallery(label="差异可视化")
+     result_text = gr.Textbox(label="分析结果", lines=5)

+     def process_batch(images_a, images_b, api_key, api_type):
          images_a = [Image.open(img).convert("RGB") for img in images_a]
          images_b = [Image.open(img).convert("RGB") for img in images_b]
+         # Build outputs that match the (gallery, textbox) components wired below,
+         # rather than returning the raw list of result dicts
+         gallery, texts = [], []
+         for i, (img_a, img_b) in enumerate(zip(images_a, images_b)):
+             result = analyze_images(img_a, img_b, api_key, api_type, f"comparison_{i+1}")
+             gallery.append((result["bar_chart"], f"Pair {i+1}: bar chart"))
+             gallery.append((result["pie_chart"], f"Pair {i+1}: pie chart"))
+             texts.append(f"Pair {i+1}:\n{result['text_analysis']}")
+         return gallery, "\n\n".join(texts)

+     analyze_button.click(process_batch, inputs=[images_a_input, images_b_input, api_key_input, api_type_input], outputs=[result_gallery, result_text])

  demo.launch()