import gradio as gr
from PIL import Image
import torch
import numpy as np
import faiss

from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    CLIPProcessor,
    CLIPModel
)
from datasets import load_dataset

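# Stream the WikiArt dataset so it does not have to be fully downloaded before the app starts.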
wikiart_dataset = load_dataset("huggan/wikiart", split="train", streaming=True)

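# Streaming datasets have no random access, so each lookup is a linear scan from
# the start of the stream (O(idx) per call).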
def get_item_streaming(dataset, idx):
    for i, item in enumerate(dataset):
        if i == idx:
            return item
    raise IndexError("Index out of range")

device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

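# BLIP turns the query image into a natural-language caption.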
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device).eval()

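# CLIP embeds text and images into a shared space, the same space the FAISS indices
# are assumed to be built in.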
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

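# Prebuilt FAISS indices loaded from local files: one over image embeddings, one over
# text embeddings of the WikiArt collection. Query vectors are L2-normalized below,
# so the index vectors are presumably normalized CLIP vectors as well.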
image_index = faiss.read_index("image_index.faiss")
text_index = faiss.read_index("text_index.faiss")

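# Generate a short caption for the query image with BLIP.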
def generate_caption(image: Image.Image):
    inputs = blip_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        caption_ids = blip_model.generate(**inputs)
    caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)
    return caption

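# CLIP text embedding, L2-normalized so that similarity search behaves like cosine
# similarity (assuming the index vectors are normalized too).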
def get_clip_text_embedding(text):
    inputs = clip_processor(text=[text], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        features = clip_model.get_text_features(**inputs)
    features = features.cpu().numpy().astype("float32")
    faiss.normalize_L2(features)
    return features

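# CLIP image embedding, normalized the same way as the text embedding.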
def get_clip_image_embedding(image):
    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)
    features = features.cpu().numpy().astype("float32")
    faiss.normalize_L2(features)
    return features
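
# A minimal sketch of how an index file such as image_index.faiss could be built
# (an assumption for illustration only; the indices used here are prebuilt, and
# their exact construction is not part of this app):
def build_image_index(images, path="image_index.faiss"):
    # Inner-product index over L2-normalized CLIP vectors == cosine-similarity search.
    index = faiss.IndexFlatIP(clip_model.config.projection_dim)
    for img in images:
        index.add(get_clip_image_embedding(img))
    faiss.write_index(index, path)
    return index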

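# Search a FAISS index with the query embedding and fetch the matching WikiArt
# records (image + metadata) for display.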
def get_results_with_images(embedding, index, top_k=2):
    D, I = index.search(embedding, top_k)
    results = []
    for idx in I[0]:
        try:
            item = get_item_streaming(wikiart_dataset, int(idx))
            img = item["image"]
            title = item.get("title", "Untitled")
            artist = item.get("artist", "Unknown")
            caption = f"ID: {idx}\n{title}\n{artist}"
            results.append((img, caption))
        except IndexError:
            continue
    return results

# Main search function: caption the image with BLIP, embed both the caption and the
# image with CLIP, then query the corresponding FAISS indices.
def search_similar_images(image: Image.Image):
    caption = generate_caption(image)
    text_emb = get_clip_text_embedding(caption)
    image_emb = get_clip_image_embedding(image)

    text_results = get_results_with_images(text_emb, text_index)
    image_results = get_results_with_images(image_emb, image_index)

    return caption, text_results, image_results

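# Gradio UI: a single image input; outputs are the BLIP caption and two galleries
# of nearest neighbours (caption-based and image-based search).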
demo = gr.Interface(
    fn=search_similar_images,
    inputs=gr.Image(label="Upload an image", type="pil"),
    outputs=[
        gr.Textbox(label="📜 Generated caption"),
        gr.Gallery(label="🔍 Similar by caption (CLIP)", height="auto", columns=2),
        gr.Gallery(label="🎨 Similar by image (CLIP)", height="auto", columns=2)
    ],
    title="🎨 Semantic WikiArt Search (BLIP + CLIP)",
    description="Upload an image. BLIP generates a caption, and CLIP finds similar paintings by caption text and by image content."
)

demo.launch()