LeBuH committed
Commit 22c8053 · verified · 1 Parent(s): 9e182a3

Upload app.py

Files changed (1)
  1. app.py +325 -0
app.py ADDED
@@ -0,0 +1,325 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "78ab80c4-8e25-4464-b710-087d385349fe",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/homebrew/Cellar/jupyterlab/4.4.0/libexec/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "import gradio as gr\n",
+ "from PIL import Image\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "import faiss\n",
+ "import json\n",
+ "\n",
+ "from transformers import (\n",
+ " BlipProcessor,\n",
+ " BlipForConditionalGeneration,\n",
+ " CLIPProcessor,\n",
+ " CLIPModel\n",
+ ")\n",
+ "from datasets import load_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "9e6fe9c1-df25-41ad-ab27-f6fc20ecb956",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wikiart_dataset = load_dataset(\"huggan/wikiart\", split=\"train\")\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\")"
+ ]
+ },
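The lookup wikiart_dataset[img_id]["image"] further down assumes each record exposes a decoded PIL image under an "image" key. A minimal sanity check, assuming the huggan/wikiart feature names:

    # Quick check after the cell above: confirm the split size and that
    # each record carries an "image" field (relied on later when mapping
    # FAISS hits back to pictures). Assumes huggan/wikiart's schema.
    print(len(wikiart_dataset))   # records in the train split
    sample = wikiart_dataset[0]   # one record, returned as a dict
    print(sample.keys())          # expect "image" among the keys
    print(sample["image"].size)   # PIL.Image, so .size is (width, height)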
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "b9da3ff0-62e6-4686-af9f-38183f675788",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.\n"
+ ]
+ }
+ ],
+ "source": [
+ "blip_processor = BlipProcessor.from_pretrained(\"Salesforce/blip-image-captioning-base\")\n",
+ "blip_model = BlipForConditionalGeneration.from_pretrained(\"Salesforce/blip-image-captioning-base\").to(device).eval()"
+ ]
+ },
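The stderr output above is only the processor's slow/fast warning. As the message itself suggests, passing use_fast=True opts into the fast image processor ahead of the v4.52 default; a sketch (minor numerical differences versus the slow processor are possible):

    # Optional: opt into the fast image processor explicitly, as the
    # warning text suggests; outputs may differ slightly from the slow one.
    blip_processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base",
        use_fast=True,
    )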
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "12d9402a-fdbe-4ade-99ed-26f5d7f9ccfd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\").to(device).eval()\n",
+ "clip_processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "d4f5e7b2-c873-4495-8ad1-9e32f4f1fbe1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(\"../create_embeddings/wikiart_embeddings.json\", \"r\", encoding=\"utf-8\") as f:\n",
+ " data = json.load(f)"
+ ]
+ },
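The relative path assumes the notebook runs next to a create_embeddings directory; the JSON file itself is not part of this commit. From the lookups data[idx]["id"] and data[idx]["caption"] below, each entry presumably looks like this assumed sketch:

    # Assumed shape of one wikiart_embeddings.json entry, inferred from
    # how `data` is indexed later; the embedding vectors themselves live
    # in the FAISS indexes, not in this file.
    example_entry = {
        "id": "123",                     # row index into huggan/wikiart
        "caption": "a painting of ...",  # caption stored for that image
    }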
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "87bc4121-f316-4769-bf5d-197db30fe2a3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image_index = faiss.read_index(\"../create_index/image_index.faiss\")\n",
+ "text_index = faiss.read_index(\"../create_index/text_index.faiss\")"
+ ]
+ },
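The indexes are read prebuilt and are not part of this commit. Since the query code below L2-normalizes embeddings before searching, the indexes were presumably built for inner-product (cosine) search; a minimal compatible build, assuming an (n, 512) float32 matrix of CLIP ViT-B/32 embeddings:

    # Hypothetical build step for a compatible index: with L2-normalized
    # rows, inner product equals cosine similarity, so IndexFlatIP matches
    # the normalized queries used below. Placeholder vectors, shape only.
    import faiss
    import numpy as np

    embs = np.random.rand(1000, 512).astype("float32")
    faiss.normalize_L2(embs)                  # normalize rows in place
    index = faiss.IndexFlatIP(embs.shape[1])  # exact inner-product index
    index.add(embs)
    faiss.write_index(index, "image_index.faiss")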
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "b41d1e5c-d606-4501-a22c-3cde576361d7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def generate_caption(image: Image.Image):\n",
+ " inputs = blip_processor(image, return_tensors=\"pt\").to(device)\n",
+ " with torch.no_grad():\n",
+ " caption_ids = blip_model.generate(**inputs)\n",
+ " caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)\n",
+ " return caption"
+ ]
+ },
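A standalone smoke test for the captioner; the file name is a placeholder, not something from the repository:

    # Hypothetical check: caption one local image. "sample.jpg" is a
    # placeholder path.
    img = Image.open("sample.jpg").convert("RGB")
    print(generate_caption(img))  # e.g. "a painting of ..."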
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "263c8672-f4b4-46b7-abf0-483ccfb83c86",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_clip_text_embedding(text):\n",
+ " inputs = clip_processor(text=[text], return_tensors=\"pt\", padding=True).to(device)\n",
+ " with torch.no_grad():\n",
+ " features = clip_model.get_text_features(**inputs)\n",
+ " features = features.cpu().numpy().astype(\"float32\")\n",
+ " faiss.normalize_L2(features)\n",
+ " return features"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "34827bd8-e0da-4252-b168-3c79f2d2fb02",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_clip_image_embedding(image):\n",
+ " inputs = clip_processor(images=image, return_tensors=\"pt\").to(device)\n",
+ " with torch.no_grad():\n",
+ " features = clip_model.get_image_features(**inputs)\n",
+ " features = features.cpu().numpy().astype(\"float32\")\n",
+ " faiss.normalize_L2(features)\n",
+ " return features"
+ ]
+ },
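Both helpers return L2-normalized float32 rows, so the inner product of any two embeddings is their cosine similarity, which is exactly what an inner-product FAISS index scores. A small check under that assumption:

    # After faiss.normalize_L2 each row has unit norm, so the dot product
    # of a text and an image embedding is their cosine similarity.
    t = get_clip_text_embedding("a portrait of a woman")
    v = get_clip_image_embedding(Image.new("RGB", (224, 224), "gray"))
    print(np.linalg.norm(t), np.linalg.norm(v))  # both close to 1.0
    print((t @ v.T).item())                      # cosine similarity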
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "ec6399ac-a40d-49f7-9831-3085fca484c9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_results_with_images(embedding, index, top_k=2):\n",
+ " D, I = index.search(embedding, top_k)\n",
+ " results = []\n",
+ " for idx in I[0]:\n",
+ " item = data[idx]\n",
+ " img_id = int(item[\"id\"])\n",
+ " try:\n",
+ " img = wikiart_dataset[img_id][\"image\"]\n",
+ " except IndexError:\n",
+ " continue\n",
+ " caption = f\"ID: {item['id']}\\n{item['caption']}\"\n",
+ " results.append((img, caption))\n",
+ " return results"
+ ]
+ },
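One edge case this cell leaves open: FAISS pads results with -1 when an index holds fewer than top_k vectors, and data[-1] would then silently alias the last entry instead of failing. A hedged defensive variant:

    # Variant that skips FAISS's -1 padding (returned when fewer than
    # top_k neighbors exist) before indexing into `data`.
    def get_results_with_images_safe(embedding, index, top_k=2):
        D, I = index.search(embedding, top_k)
        results = []
        for idx in I[0]:
            if idx < 0:  # -1 marks "no neighbor found"
                continue
            item = data[idx]
            try:
                img = wikiart_dataset[int(item["id"])]["image"]
            except IndexError:
                continue
            results.append((img, f"ID: {item['id']}\n{item['caption']}"))
        return results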
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "76adeb1c-85d6-4e53-9c93-a312c21b71b8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def search_similar_images(image: Image.Image):\n",
+ " caption = generate_caption(image)\n",
+ "\n",
+ " text_emb = get_clip_text_embedding(caption)\n",
+ " image_emb = get_clip_image_embedding(image)\n",
+ "\n",
+ " text_results = get_results_with_images(text_emb, text_index)\n",
+ " image_results = get_results_with_images(image_emb, image_index)\n",
+ "\n",
+ " return caption, text_results, image_results"
+ ]
+ },
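The full pipeline can be exercised without the Gradio UI; the query path below is a placeholder:

    # Hypothetical end-to-end run outside Gradio; "query.jpg" is a
    # placeholder path, not part of the commit.
    query = Image.open("query.jpg").convert("RGB")
    caption, by_text, by_image = search_similar_images(query)
    print("BLIP caption:", caption)
    print("text-route hits:", [c for _, c in by_text])
    print("image-route hits:", [c for _, c in by_image])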
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "da86df12-a996-4d1d-ae42-354984cf6dc2",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "* Running on local URL: http://127.0.0.1:7862\n",
+ "* To create a public link, set `share=True` in `launch()`.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<div><iframe src=\"http://127.0.0.1:7862/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": []
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "demo = gr.Interface(\n",
+ " fn=search_similar_images,\n",
+ " inputs=gr.Image(label=\"Upload an image\", type=\"pil\"),\n",
+ " outputs=[\n",
+ " gr.Textbox(label=\"📜 Generated caption\"),\n",
+ " gr.Gallery(label=\"🔍 Similar by caption (CLIP)\", height=\"auto\", columns=2),\n",
+ " gr.Gallery(label=\"🎨 Similar by image (CLIP)\", height=\"auto\", columns=2)\n",
+ " ],\n",
+ " title=\"🎨 Semantic WikiArt Search (BLIP + CLIP)\",\n",
+ " description=\"Upload an image. BLIP will generate a caption, and CLIP will find similar paintings by text and by image.\"\n",
+ ")\n",
+ "\n",
+ "demo.launch()"
+ ]
+ },
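This launch leaves a server running on a local port; the next cell calls launch() again with different parameters. Per Gradio's own rerun hint, the running instance can be stopped first:

    # Optional: stop the server started above before relaunching with
    # different launch() parameters, as Gradio's rerun message suggests.
    demo.close()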
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "55fbac06-4781-4074-a1e6-26ff758bbfe0",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Rerunning server... use `close()` to stop if you need to change `launch()` parameters.\n",
+ "----\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+ "To disable this warning, you can either:\n",
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "* Running on public URL: https://ba46916423948a3a69.gradio.live\n",
+ "\n",
+ "This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<div><iframe src=\"https://ba46916423948a3a69.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": []
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "demo.launch(server_name=\"0.0.0.0\", server_port=7860, share=True)\n"
+ ]
+ },
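The tokenizers fork warning in this cell's stderr is benign; as the message itself says, it can be silenced by setting TOKENIZERS_PARALLELISM before any tokenizer is used:

    # Silences the huggingface/tokenizers fork warning, per the hint in
    # the stderr output above; set before transformers loads a tokenizer.
    import os
    os.environ["TOKENIZERS_PARALLELISM"] = "false"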
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c44447c3-0709-4419-a6a4-fc451f80702a",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }