import gradio as gr
from PIL import Image
import torch
import numpy as np
import faiss
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    CLIPProcessor,
    CLIPModel
)
from datasets import load_dataset

# Stream WikiArt so the full dataset is not downloaded up front.
wikiart_dataset = load_dataset("huggan/wikiart", split="train", streaming=True)

def get_item_streaming(dataset, idx):
    # Streaming datasets have no random access: walk the stream until idx is reached (O(idx) per lookup).
    for i, item in enumerate(dataset):
        if i == idx:
            return item
    raise IndexError("Index out of range")
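
# Note: each call to get_item_streaming restarts the stream, so fetching k
# results costs k passes over the dataset prefix. A batched variant
# (hypothetical, not used below) could collect several indices in one pass:
def get_items_streaming(dataset, indices):
    wanted = {int(i) for i in indices}
    found = {}
    for i, item in enumerate(dataset):
        if i in wanted:
            found[i] = item
            if len(found) == len(wanted):
                break
    return found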

# Prefer CUDA, then Apple Silicon (MPS), then CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

# BLIP generates a caption for the query image; CLIP embeds both text and
# images into the shared space the FAISS indices were built in.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device).eval()
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Precomputed FAISS indices over CLIP embeddings of the WikiArt collection.
image_index = faiss.read_index("image_index.faiss")
text_index = faiss.read_index("text_index.faiss")
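
# The two indices above are only read here, so how they were built is an
# assumption. A minimal sketch of the presumed offline step: L2-normalised
# CLIP embeddings added to an exact inner-product index, making search
# scores cosine similarities (consistent with normalize_L2 in the query
# helpers below).
def build_index_sketch(embeddings, path):
    vecs = np.array(embeddings, dtype="float32")
    faiss.normalize_L2(vecs)  # in-place L2 normalisation
    index = faiss.IndexFlatIP(vecs.shape[1])
    index.add(vecs)
    faiss.write_index(index, path)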

def generate_caption(image: Image.Image):
    # BLIP: query image -> caption text.
    inputs = blip_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        caption_ids = blip_model.generate(**inputs)
    caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)
    return caption

def get_clip_text_embedding(text):
    inputs = clip_processor(text=[text], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        features = clip_model.get_text_features(**inputs)
    features = features.cpu().numpy().astype("float32")
    # L2-normalise so inner-product search in FAISS equals cosine similarity.
    faiss.normalize_L2(features)
    return features

def get_clip_image_embedding(image):
    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)
    features = features.cpu().numpy().astype("float32")
    faiss.normalize_L2(features)
    return features

def get_results_with_images(embedding, index, top_k=2):
    # FAISS returns distances D and row indices I into the indexed collection.
    D, I = index.search(embedding, top_k)
    results = []
    for idx in I[0]:
        try:
            item = get_item_streaming(wikiart_dataset, int(idx))
            img = item["image"]
            title = item.get("title", "Untitled")
            artist = item.get("artist", "Unknown")
            caption = f"ID: {idx}\n{title} — {artist}"
            results.append((img, caption))
        except IndexError:
            continue
    return results

# Main search function: caption the query image with BLIP, then retrieve
# neighbours by the caption's text embedding and by the image embedding.
def search_similar_images(image: Image.Image):
    caption = generate_caption(image)
    text_emb = get_clip_text_embedding(caption)
    image_emb = get_clip_image_embedding(image)
    text_results = get_results_with_images(text_emb, text_index)
    image_results = get_results_with_images(image_emb, image_index)
    return caption, text_results, image_results

demo = gr.Interface(
    fn=search_similar_images,
    inputs=gr.Image(label="Upload an image", type="pil"),
    outputs=[
        gr.Textbox(label="📜 Generated caption"),
        gr.Gallery(label="🔍 Similar by caption (CLIP)", height="auto", columns=2),
        gr.Gallery(label="🎨 Similar by image (CLIP)", height="auto", columns=2)
    ],
    title="🎨 Semantic WikiArt Search (BLIP + CLIP)",
    description="Upload an image. BLIP generates a caption, and CLIP finds similar paintings by text and by image."
)

demo.launch()
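
# Usage note (assumption: the script is run directly, e.g. `python app.py`).
# Passing share=True to demo.launch() would also create a temporary public URL.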