PSNbst's picture
Update app.py
b80196d verified
raw
history blame
3.64 kB
import os

import gradio as gr
import numpy as np
import openai  # OpenAI GPT API client
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditionalGeneration
# --- Model initialization (runs once at import; first run downloads weights) ---
# CLIP: image embeddings for similarity; BLIP: image captioning.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# GPT API configuration.
# Read the key from the environment rather than hard-coding a secret in
# source; fall back to the original placeholder so behavior is unchanged
# when the variable is unset.
openai.api_key = os.environ.get("OPENAI_API_KEY", "your_openai_api_key")
def analyze_images(image_a, image_b):
    """Compare two images and summarize their differences.

    Pipeline: BLIP generates a caption for each image, CLIP extracts image
    embeddings whose cosine similarity and element-wise difference are
    computed, and the GPT API turns the two captions into a textual analysis.

    Args:
        image_a: First image — a PIL.Image (as delivered by the Gradio
            ``type="pil"`` inputs) or a path / file-like object.
        image_b: Second image, same accepted types.

    Returns:
        dict with keys ``caption_a``, ``caption_b``, ``similarity`` (float
        cosine similarity of CLIP embeddings), ``latent_diff`` (nested list,
        absolute element-wise embedding difference) and ``text_analysis``
        (GPT-generated summary).
    """

    def _generate_caption(image):
        # BLIP captioning; no_grad avoids building an unused autograd graph.
        inputs = blip_processor(image, return_tensors="pt")
        with torch.no_grad():
            caption_ids = blip_model.generate(**inputs)
        return blip_processor.decode(caption_ids[0], skip_special_tokens=True)

    def _extract_features(image):
        # CLIP image embedding as a (1, dim) numpy array.
        inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            features = clip_model.get_image_features(**inputs)
        return features.detach().numpy()

    def _to_rgb(image):
        # Bug fix: the Gradio inputs use type="pil", so the handler receives
        # PIL Images — calling Image.open() on one raises. Only open when
        # given a path / file-like object.
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        return image.convert("RGB")

    img_a = _to_rgb(image_a)
    img_b = _to_rgb(image_b)

    # Captions for both images.
    caption_a = _generate_caption(img_a)
    caption_b = _generate_caption(img_b)

    # Embedding features for both images.
    features_a = _extract_features(img_a)
    features_b = _extract_features(img_b)

    # Cosine similarity between the two (1, dim) embeddings.
    cosine_similarity = np.dot(features_a, features_b.T) / (
        np.linalg.norm(features_a) * np.linalg.norm(features_b)
    )
    latent_diff = np.abs(features_a - features_b).tolist()

    # Ask GPT for a textual comparison of the two captions.
    gpt_prompt = (
        f"图片A的描述为:{caption_a}。图片B的描述为:{caption_b}。\n"
        "请对两张图片的内容和潜在特征区别进行详细分析,并输出一个简洁但富有条理的总结。"
    )
    # NOTE(review): openai.Completion / text-davinci-003 is the legacy
    # completions API, removed in openai>=1.0 and deprecated server-side;
    # migrate to the chat-completions API when the dependency is updated.
    gpt_response = openai.Completion.create(
        engine="text-davinci-003",
        prompt=gpt_prompt,
        max_tokens=150,
    )
    textual_analysis = gpt_response['choices'][0]['text'].strip()

    return {
        "caption_a": caption_a,
        "caption_b": caption_b,
        # float() so downstream widgets get a plain Python scalar.
        "similarity": float(cosine_similarity[0][0]),
        "latent_diff": latent_diff,
        "text_analysis": textual_analysis,
    }
# --- Gradio interface ---
with gr.Blocks() as demo:
    gr.Markdown("# 图片对比分析工具")

    # Two side-by-side image inputs; type="pil" hands PIL Images to the handler.
    with gr.Row():
        with gr.Column():
            image_a = gr.Image(label="图片A", type="pil")
        with gr.Column():
            image_b = gr.Image(label="图片B", type="pil")

    analyze_button = gr.Button("分析图片")

    # Read-only output widgets, one per field of the analysis result.
    result_caption_a = gr.Textbox(label="图片A描述", interactive=False)
    result_caption_b = gr.Textbox(label="图片B描述", interactive=False)
    result_similarity = gr.Number(label="图片相似性", interactive=False)
    result_latent_diff = gr.DataFrame(label="潜在特征差异", interactive=False)
    result_text_analysis = gr.Textbox(label="详细分析", interactive=False, lines=5)

    def process_analysis(img_a, img_b):
        """Adapter: unpack analyze_images()'s dict in the outputs' order."""
        outcome = analyze_images(img_a, img_b)
        return (
            outcome["caption_a"],
            outcome["caption_b"],
            outcome["similarity"],
            outcome["latent_diff"],
            outcome["text_analysis"],
        )

    analyze_button.click(
        fn=process_analysis,
        inputs=[image_a, image_b],
        outputs=[
            result_caption_a,
            result_caption_b,
            result_similarity,
            result_latent_diff,
            result_text_analysis,
        ],
    )

demo.launch()