LeBuH committed on
Commit a05ebfa · verified · 1 Parent(s): 39c508d

Delete app.py

Files changed (1)
  1. app.py +0 -95
app.py DELETED
@@ -1,95 +0,0 @@
- import gradio as gr
- from PIL import Image
- import torch
- import numpy as np
- import faiss
-
- from transformers import (
-     BlipProcessor,
-     BlipForConditionalGeneration,
-     CLIPProcessor,
-     CLIPModel
- )
- from datasets import load_dataset
-
- # Stream the WikiArt dataset instead of downloading it in full
- wikiart_dataset = load_dataset("huggan/wikiart", split="train", streaming=True)
-
- def get_item_streaming(dataset, idx):
-     # Iterate the streaming dataset until the requested index is reached
-     for i, item in enumerate(dataset):
-         if i == idx:
-             return item
-     raise IndexError("Index out of range")
-
- device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
-
- # BLIP for image captioning
- blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
- blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device).eval()
-
- # CLIP for text and image embeddings
- clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
- clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
-
- # Prebuilt FAISS indexes over the WikiArt collection
- image_index = faiss.read_index("image_index.faiss")
- text_index = faiss.read_index("text_index.faiss")
-
- def generate_caption(image: Image.Image):
-     inputs = blip_processor(image, return_tensors="pt").to(device)
-     with torch.no_grad():
-         caption_ids = blip_model.generate(**inputs)
-     caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)
-     return caption
-
- def get_clip_text_embedding(text):
-     inputs = clip_processor(text=[text], return_tensors="pt", padding=True).to(device)
-     with torch.no_grad():
-         features = clip_model.get_text_features(**inputs)
-     features = features.cpu().numpy().astype("float32")
-     faiss.normalize_L2(features)
-     return features
-
- def get_clip_image_embedding(image):
-     inputs = clip_processor(images=image, return_tensors="pt").to(device)
-     with torch.no_grad():
-         features = clip_model.get_image_features(**inputs)
-     features = features.cpu().numpy().astype("float32")
-     faiss.normalize_L2(features)
-     return features
-
- def get_results_with_images(embedding, index, top_k=2):
-     # Search the FAISS index and resolve each hit back to its dataset record
-     D, I = index.search(embedding, top_k)
-     results = []
-     for idx in I[0]:
-         try:
-             item = get_item_streaming(wikiart_dataset, int(idx))
-             img = item["image"]
-             title = item.get("title", "Untitled")
-             artist = item.get("artist", "Unknown")
-             caption = f"ID: {idx}\n{title} — {artist}"
-             results.append((img, caption))
-         except IndexError:
-             continue
-     return results
-
- # Main search function
- def search_similar_images(image: Image.Image):
-     caption = generate_caption(image)
-     text_emb = get_clip_text_embedding(caption)
-     image_emb = get_clip_image_embedding(image)
-
-     text_results = get_results_with_images(text_emb, text_index)
-     image_results = get_results_with_images(image_emb, image_index)
-
-     return caption, text_results, image_results
-
- demo = gr.Interface(
-     fn=search_similar_images,
-     inputs=gr.Image(label="Upload an image", type="pil"),
-     outputs=[
-         gr.Textbox(label="📜 Generated caption"),
-         gr.Gallery(label="🔍 Similar by caption (CLIP)", height="auto", columns=2),
-         gr.Gallery(label="🎨 Similar by image (CLIP)", height="auto", columns=2)
-     ],
-     title="🎨 Semantic WikiArt Search (BLIP + CLIP)",
-     description="Upload an image. BLIP generates a caption, and CLIP finds similar paintings by text and by image."
- )
-
- demo.launch()