import gradio as gr
from PIL import Image
import torch
import numpy as np
import faiss
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    CLIPProcessor,
    CLIPModel
)
from datasets import load_dataset

# Stream the WikiArt dataset so it does not need to be fully downloaded up front.
wikiart_dataset = load_dataset("huggan/wikiart", split="train", streaming=True)


def get_item_streaming(dataset, idx):
    """Fetch the item at position `idx` from a streaming dataset by iterating to it."""
    for i, item in enumerate(dataset):
        if i == idx:
            return item
    raise IndexError("Index out of range")


device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# BLIP generates captions; CLIP embeds text and images into a shared vector space.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
).to(device).eval()

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Prebuilt FAISS indexes over CLIP embeddings of the WikiArt collection.
image_index = faiss.read_index("image_index.faiss")
text_index = faiss.read_index("text_index.faiss")


def generate_caption(image: Image.Image):
    """Generate a BLIP caption for the uploaded image."""
    inputs = blip_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        caption_ids = blip_model.generate(**inputs)
    caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)
    return caption


def get_clip_text_embedding(text):
    """Embed text with CLIP and L2-normalize it for cosine-similarity search."""
    inputs = clip_processor(text=[text], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        features = clip_model.get_text_features(**inputs)
    features = features.cpu().numpy().astype("float32")
    faiss.normalize_L2(features)
    return features


def get_clip_image_embedding(image):
    """Embed an image with CLIP and L2-normalize it for cosine-similarity search."""
    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)
    features = features.cpu().numpy().astype("float32")
    faiss.normalize_L2(features)
    return features


def get_results_with_images(embedding, index, top_k=2):
    """Search a FAISS index and return (image, caption) pairs for the top hits."""
    D, I = index.search(embedding, top_k)
    results = []
    for idx in I[0]:
        try:
            item = get_item_streaming(wikiart_dataset, int(idx))
            img = item["image"]
            title = item.get("title", "Untitled")
            artist = item.get("artist", "Unknown")
            caption = f"ID: {idx}\n{title} — {artist}"
            results.append((img, caption))
        except IndexError:
            continue
    return results


# Main search function
def search_similar_images(image: Image.Image):
    caption = generate_caption(image)
    text_emb = get_clip_text_embedding(caption)
    image_emb = get_clip_image_embedding(image)
    text_results = get_results_with_images(text_emb, text_index)
    image_results = get_results_with_images(image_emb, image_index)
    return caption, text_results, image_results


demo = gr.Interface(
    fn=search_similar_images,
    inputs=gr.Image(label="Upload an image", type="pil"),
    outputs=[
        gr.Textbox(label="📜 Generated caption"),
        gr.Gallery(label="🔍 Similar by caption (CLIP)", height="auto", columns=2),
        gr.Gallery(label="🎨 Similar by image (CLIP)", height="auto", columns=2)
    ],
    title="🎨 Semantic WikiArt Search (BLIP + CLIP)",
    description="Upload an image. BLIP will generate a caption, and CLIP will find similar paintings by text and by image."
)

demo.launch()
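

# --- Index construction (not shown in the script above) ---
# The app reads "image_index.faiss" and "text_index.faiss" from disk but never
# builds them. Below is a minimal sketch of how such indexes could be created,
# assuming they store L2-normalized CLIP embeddings (so inner product equals
# cosine similarity) and that FAISS row i corresponds to dataset item i.
# The helper name, the item count, the use of IndexFlatIP, and using the stored
# title as the text field are all assumptions, not part of the original code.

def build_indexes(num_items=1000):
    image_vecs, text_vecs = [], []
    for i, item in enumerate(wikiart_dataset):
        if i >= num_items:
            break
        # Reuse the embedding helpers defined above; each returns a (1, d) array.
        image_vecs.append(get_clip_image_embedding(item["image"])[0])
        text_vecs.append(get_clip_text_embedding(item.get("title", ""))[0])

    image_matrix = np.stack(image_vecs).astype("float32")
    text_matrix = np.stack(text_vecs).astype("float32")

    # Inner-product indexes; vectors are already L2-normalized by the helpers.
    img_index = faiss.IndexFlatIP(image_matrix.shape[1])
    txt_index = faiss.IndexFlatIP(text_matrix.shape[1])
    img_index.add(image_matrix)
    txt_index.add(text_matrix)

    faiss.write_index(img_index, "image_index.faiss")
    faiss.write_index(txt_index, "text_index.faiss")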