"""Semantic WikiArt search (BLIP + CLIP).

BLIP generates a caption for an uploaded image; CLIP embeddings of both the
caption and the image are then used to retrieve similar WikiArt paintings
from two prebuilt FAISS indexes.
"""

import gradio as gr
from PIL import Image
import torch
import numpy as np
import faiss

from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    CLIPProcessor,
    CLIPModel,
)
from datasets import load_dataset

# Stream the WikiArt dataset so it does not have to be downloaded in full.
wikiart_dataset = load_dataset("huggan/wikiart", split="train", streaming=True)


def get_item_streaming(dataset, idx):
    """Return the idx-th item of a streaming dataset by iterating from the start."""
    for i, item in enumerate(dataset):
        if i == idx:
            return item
    raise IndexError("Index out of range")


# Prefer CUDA, then Apple Silicon (MPS), otherwise fall back to CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# BLIP: image captioning
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
).to(device).eval()

# CLIP: joint image/text embeddings
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load the prebuilt FAISS indexes from disk.
image_index = faiss.read_index("image_index.faiss")
text_index = faiss.read_index("text_index.faiss")
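# NOTE: image_index.faiss and text_index.faiss are expected to exist next to
# this script; they are not created here. Judging by how they are queried
# below (L2-normalized vectors + similarity search), they presumably hold CLIP
# embeddings of the WikiArt items. A hypothetical build sketch is included
# after the embedding helpers below.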


def generate_caption(image: Image.Image):
    """Generate a short caption for the image with BLIP."""
    inputs = blip_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        caption_ids = blip_model.generate(**inputs)
    caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)
    return caption

def get_clip_text_embedding(text):
    """Embed text with CLIP and L2-normalize it for cosine / inner-product search."""
    inputs = clip_processor(text=[text], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        features = clip_model.get_text_features(**inputs)
    features = features.cpu().numpy().astype("float32")
    faiss.normalize_L2(features)
    return features

def get_clip_image_embedding(image):
    """Embed an image with CLIP and L2-normalize it for cosine / inner-product search."""
    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)
    features = features.cpu().numpy().astype("float32")
    faiss.normalize_L2(features)
    return features
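
# The index-building step is not part of this script. The function below is a
# minimal, hypothetical sketch of how the two FAISS indexes could be produced
# offline, assuming inner-product (cosine) indexes over CLIP embeddings, with
# the text side embedded from the item titles. `num_items` and the use of the
# "title" field are illustrative assumptions, not taken from the original.
def build_indexes(num_items=1000):
    image_embs, text_embs = [], []
    for i, item in enumerate(wikiart_dataset):
        if i >= num_items:
            break
        image_embs.append(get_clip_image_embedding(item["image"]))
        text_embs.append(get_clip_text_embedding(item.get("title", "Untitled")))
    image_matrix = np.vstack(image_embs)  # embeddings are already L2-normalized
    text_matrix = np.vstack(text_embs)
    img_index = faiss.IndexFlatIP(image_matrix.shape[1])  # inner product == cosine here
    txt_index = faiss.IndexFlatIP(text_matrix.shape[1])
    img_index.add(image_matrix)
    txt_index.add(text_matrix)
    faiss.write_index(img_index, "image_index.faiss")
    faiss.write_index(txt_index, "text_index.faiss")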


def get_results_with_images(embedding, index, top_k=2):
    """Search a FAISS index and return (image, caption) pairs for a Gradio gallery."""
    D, I = index.search(embedding, top_k)
    results = []
    for idx in I[0]:
        try:
            item = get_item_streaming(wikiart_dataset, int(idx))
            img = item["image"]
            title = item.get("title", "Untitled")
            artist = item.get("artist", "Unknown")
            caption = f"ID: {idx}\n{title} — {artist}"
            results.append((img, caption))
        except IndexError:
            continue
    return results


def search_similar_images(image: Image.Image):
    """Caption the query image, then retrieve similar works by text and by image."""
    caption = generate_caption(image)
    text_emb = get_clip_text_embedding(caption)
    image_emb = get_clip_image_embedding(image)

    text_results = get_results_with_images(text_emb, text_index)
    image_results = get_results_with_images(image_emb, image_index)

    return caption, text_results, image_results
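
# Example (hypothetical) of calling the search directly, without the UI:
#   caption, by_text, by_image = search_similar_images(Image.open("query.jpg"))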


demo = gr.Interface(
    fn=search_similar_images,
    inputs=gr.Image(label="Upload an image", type="pil"),
    outputs=[
        gr.Textbox(label="📜 Generated caption"),
        gr.Gallery(label="🔍 Similar by caption (CLIP)", height="auto", columns=2),
        gr.Gallery(label="🎨 Similar by image (CLIP)", height="auto", columns=2),
    ],
    title="🎨 Semantic WikiArt Search (BLIP + CLIP)",
    description="Upload an image. BLIP will generate a caption, and CLIP will find similar paintings by text and by image.",
)

demo.launch()