LeBuH committed on
Commit c6ad3ab · verified · 1 parent: fd04905

Upload app.py

Files changed (1)
  1. app.py +95 -325
app.py CHANGED
@@ -1,325 +1,95 @@
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "78ab80c4-8e25-4464-b710-087d385349fe",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/opt/homebrew/Cellar/jupyterlab/4.4.0/libexec/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n"
- ]
- }
- ],
- "source": [
- "import gradio as gr\n",
- "from PIL import Image\n",
- "import torch\n",
- "import numpy as np\n",
- "import faiss\n",
- "import json\n",
- "\n",
- "from transformers import (\n",
- " BlipProcessor,\n",
- " BlipForConditionalGeneration,\n",
- " CLIPProcessor,\n",
- " CLIPModel\n",
- ")\n",
- "from datasets import load_dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "9e6fe9c1-df25-41ad-ab27-f6fc20ecb956",
- "metadata": {},
- "outputs": [],
- "source": [
- "wikiart_dataset = load_dataset(\"huggan/wikiart\", split=\"train\")\n",
- "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "b9da3ff0-62e6-4686-af9f-38183f675788",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.\n"
- ]
- }
- ],
- "source": [
- "blip_processor = BlipProcessor.from_pretrained(\"Salesforce/blip-image-captioning-base\")\n",
- "blip_model = BlipForConditionalGeneration.from_pretrained(\"Salesforce/blip-image-captioning-base\").to(device).eval()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "12d9402a-fdbe-4ade-99ed-26f5d7f9ccfd",
- "metadata": {},
- "outputs": [],
- "source": [
- "clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\").to(device).eval()\n",
- "clip_processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "d4f5e7b2-c873-4495-8ad1-9e32f4f1fbe1",
- "metadata": {},
- "outputs": [],
- "source": [
- "with open(\"../create_embeddings/wikiart_embeddings.json\", \"r\", encoding=\"utf-8\") as f:\n",
- " data = json.load(f)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "87bc4121-f316-4769-bf5d-197db30fe2a3",
- "metadata": {},
- "outputs": [],
- "source": [
- "image_index = faiss.read_index(\"../create_index/image_index.faiss\")\n",
- "text_index = faiss.read_index(\"../create_index/text_index.faiss\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "b41d1e5c-d606-4501-a22c-3cde576361d7",
- "metadata": {},
- "outputs": [],
- "source": [
- "def generate_caption(image: Image.Image):\n",
- " inputs = blip_processor(image, return_tensors=\"pt\").to(device)\n",
- " with torch.no_grad():\n",
- " caption_ids = blip_model.generate(**inputs)\n",
- " caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)\n",
- " return caption"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "id": "263c8672-f4b4-46b7-abf0-483ccfb83c86",
- "metadata": {},
- "outputs": [],
- "source": [
- "def get_clip_text_embedding(text):\n",
- " inputs = clip_processor(text=[text], return_tensors=\"pt\", padding=True).to(device)\n",
- " with torch.no_grad():\n",
- " features = clip_model.get_text_features(**inputs)\n",
- " features = features.cpu().numpy().astype(\"float32\")\n",
- " faiss.normalize_L2(features)\n",
- " return features"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "id": "34827bd8-e0da-4252-b168-3c79f2d2fb02",
- "metadata": {},
- "outputs": [],
- "source": [
- "def get_clip_image_embedding(image):\n",
- " inputs = clip_processor(images=image, return_tensors=\"pt\").to(device)\n",
- " with torch.no_grad():\n",
- " features = clip_model.get_image_features(**inputs)\n",
- " features = features.cpu().numpy().astype(\"float32\")\n",
- " faiss.normalize_L2(features)\n",
- " return features"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "id": "ec6399ac-a40d-49f7-9831-3085fca484c9",
- "metadata": {},
- "outputs": [],
- "source": [
- "def get_results_with_images(embedding, index, top_k=2):\n",
- " D, I = index.search(embedding, top_k)\n",
- " results = []\n",
- " for idx in I[0]:\n",
- " item = data[idx]\n",
- " img_id = int(item[\"id\"])\n",
- " try:\n",
- " img = wikiart_dataset[img_id][\"image\"]\n",
- " except IndexError:\n",
- " continue\n",
- " caption = f\"ID: {item['id']}\\n{item['caption']}\"\n",
- " results.append((img, caption))\n",
- " return results"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "id": "76adeb1c-85d6-4e53-9c93-a312c21b71b8",
- "metadata": {},
- "outputs": [],
- "source": [
- "def search_similar_images(image: Image.Image):\n",
- " caption = generate_caption(image)\n",
- "\n",
- " text_emb = get_clip_text_embedding(caption)\n",
- " image_emb = get_clip_image_embedding(image)\n",
- "\n",
- " text_results = get_results_with_images(text_emb, text_index)\n",
- " image_results = get_results_with_images(image_emb, image_index)\n",
- "\n",
- " return caption, text_results, image_results"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "id": "da86df12-a996-4d1d-ae42-354984cf6dc2",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "* Running on local URL: http://127.0.0.1:7862\n",
- "* To create a public link, set `share=True` in `launch()`.\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "<div><iframe src=\"http://127.0.0.1:7862/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
- ],
- "text/plain": [
- "<IPython.core.display.HTML object>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/plain": []
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "demo = gr.Interface(\n",
- " fn=search_similar_images,\n",
- " inputs=gr.Image(label=\"Загрузите изображение\", type=\"pil\"),\n",
- " outputs=[\n",
- " gr.Textbox(label=\"📜 Сгенерированное описание\"),\n",
- " gr.Gallery(label=\"🔍 Похожие по описанию (CLIP)\", height=\"auto\", columns=2),\n",
- " gr.Gallery(label=\"🎨 Похожие по изображению (CLIP)\", height=\"auto\", columns=2)\n",
- " ],\n",
- " title=\"🎨 Semantic WikiArt Search (BLIP + CLIP)\",\n",
- " description=\"Загрузите изображение. Модель BLIP сгенерирует описание, а CLIP найдёт похожие картины по тексту и изображению.\"\n",
- ")\n",
- "\n",
- "demo.launch()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "id": "55fbac06-4781-4074-a1e6-26ff758bbfe0",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Rerunning server... use `close()` to stop if you need to change `launch()` parameters.\n",
- "----\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
- "To disable this warning, you can either:\n",
- "\t- Avoid using `tokenizers` before the fork if possible\n",
- "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "* Running on public URL: https://ba46916423948a3a69.gradio.live\n",
- "\n",
- "This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "<div><iframe src=\"https://ba46916423948a3a69.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
- ],
- "text/plain": [
- "<IPython.core.display.HTML object>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/plain": []
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "demo.launch(server_name=\"0.0.0.0\", server_port=7860, share=True)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c44447c3-0709-4419-a6a4-fc451f80702a",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.13.3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
- }
 
+ import gradio as gr
+ from PIL import Image
+ import torch
+ import numpy as np
+ import faiss
+
+ from transformers import (
+     BlipProcessor,
+     BlipForConditionalGeneration,
+     CLIPProcessor,
+     CLIPModel
+ )
+ from datasets import load_dataset
+
+ # Load the dataset and the models
+ wikiart_dataset = load_dataset("huggan/wikiart", split="train")
+ device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device).eval()
+
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ # Load the FAISS indexes
+ image_index = faiss.read_index("image_index.faiss")
+ text_index = faiss.read_index("text_index.faiss")
+
+ # Generate a caption with BLIP
+ def generate_caption(image: Image.Image):
+     inputs = blip_processor(image, return_tensors="pt").to(device)
+     with torch.no_grad():
+         caption_ids = blip_model.generate(**inputs)
+     caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)
+     return caption
+
+ # Get the CLIP embedding of a text
+ def get_clip_text_embedding(text):
+     inputs = clip_processor(text=[text], return_tensors="pt", padding=True).to(device)
+     with torch.no_grad():
+         features = clip_model.get_text_features(**inputs)
+     features = features.cpu().numpy().astype("float32")
+     faiss.normalize_L2(features)
+     return features
+
+ # Get the CLIP embedding of an image
+ def get_clip_image_embedding(image):
+     inputs = clip_processor(images=image, return_tensors="pt").to(device)
+     with torch.no_grad():
+         features = clip_model.get_image_features(**inputs)
+     features = features.cpu().numpy().astype("float32")
+     faiss.normalize_L2(features)
+     return features
+
+ # Retrieve similar images for an embedding
+ def get_results_with_images(embedding, index, top_k=2):
+     D, I = index.search(embedding, top_k)
+     results = []
+     for idx in I[0]:
+         try:
+             item = wikiart_dataset[idx]
+             img = item["image"]
+             title = item.get("title", "Untitled")
+             artist = item.get("artist", "Unknown")
+             caption = f"ID: {idx}\n{title} — {artist}"
+             results.append((img, caption))
+         except IndexError:
+             continue
+     return results
+
+ # Main search function
+ def search_similar_images(image: Image.Image):
+     caption = generate_caption(image)
+     text_emb = get_clip_text_embedding(caption)
+     image_emb = get_clip_image_embedding(image)
+
+     text_results = get_results_with_images(text_emb, text_index)
+     image_results = get_results_with_images(image_emb, image_index)
+
+     return caption, text_results, image_results
+
+ # Gradio interface
+ demo = gr.Interface(
+     fn=search_similar_images,
+     inputs=gr.Image(label="Загрузите изображение", type="pil"),
+     outputs=[
+         gr.Textbox(label="📜 Сгенерированное описание"),
+         gr.Gallery(label="🔍 Похожие по описанию (CLIP)", height="auto", columns=2),
+         gr.Gallery(label="🎨 Похожие по изображению (CLIP)", height="auto", columns=2)
+     ],
+     title="🎨 Semantic WikiArt Search (BLIP + CLIP)",
+     description="Загрузите изображение. Модель BLIP сгенерирует описание, а CLIP найдёт похожие картины по тексту и изображению."
+ )
+
+ demo.launch()
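The uploaded app.py reads image_index.faiss and text_index.faiss from its working directory (the notebook version loaded them from ../create_index/ and also used a precomputed wikiart_embeddings.json), but the index-building step is not part of this commit. The sketch below is one plausible way to produce compatible indexes. It assumes cosine search via faiss.IndexFlatIP over L2-normalized CLIP embeddings, BLIP-generated captions for the text index, and FAISS ids that follow the huggan/wikiart row order, which is how get_results_with_images maps hits back to wikiart_dataset[idx]; none of this is specified by the commit itself.

```python
# Hypothetical index-building sketch; not part of this commit.
# Assumes FAISS row ids match the dataset row order and cosine (inner-product) search.
import faiss
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    CLIPProcessor,
    CLIPModel,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device).eval()
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

wikiart = load_dataset("huggan/wikiart", split="train")

image_vecs, text_vecs = [], []
for row in wikiart:  # slow: one forward pass per painting, shown only for clarity
    image = row["image"]
    with torch.no_grad():
        # CLIP embedding of the painting itself (feeds image_index.faiss)
        img_feat = clip_model.get_image_features(
            **clip_processor(images=image, return_tensors="pt").to(device)
        )
        # BLIP caption, then its CLIP text embedding (feeds text_index.faiss)
        caption_ids = blip_model.generate(
            **blip_processor(image, return_tensors="pt").to(device)
        )
        caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)
        txt_feat = clip_model.get_text_features(
            **clip_processor(text=[caption], return_tensors="pt", padding=True).to(device)
        )
    image_vecs.append(img_feat.cpu().numpy().astype("float32"))
    text_vecs.append(txt_feat.cpu().numpy().astype("float32"))

image_mat, text_mat = np.vstack(image_vecs), np.vstack(text_vecs)
faiss.normalize_L2(image_mat)  # normalize so inner product equals cosine similarity
faiss.normalize_L2(text_mat)

image_index = faiss.IndexFlatIP(image_mat.shape[1])
text_index = faiss.IndexFlatIP(text_mat.shape[1])
image_index.add(image_mat)
text_index.add(text_mat)
faiss.write_index(image_index, "image_index.faiss")
faiss.write_index(text_index, "text_index.faiss")
```

With L2-normalized vectors, inner-product search is equivalent to cosine similarity, which matches the faiss.normalize_L2 calls already present on the query side in app.py.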