ArtMatch / app.py
LeBuH's picture
Upload app.py
22c8053 verified
raw
history blame
10.3 kB
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "78ab80c4-8e25-4464-b710-087d385349fe",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Cellar/jupyterlab/4.4.0/libexec/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import gradio as gr\n",
"from PIL import Image\n",
"import torch\n",
"import numpy as np\n",
"import faiss\n",
"import json\n",
"\n",
"from transformers import (\n",
" BlipProcessor,\n",
" BlipForConditionalGeneration,\n",
" CLIPProcessor,\n",
" CLIPModel\n",
")\n",
"from datasets import load_dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9e6fe9c1-df25-41ad-ab27-f6fc20ecb956",
"metadata": {},
"outputs": [],
"source": [
"wikiart_dataset = load_dataset(\"huggan/wikiart\", split=\"train\")\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b9da3ff0-62e6-4686-af9f-38183f675788",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.\n"
]
}
],
"source": [
"blip_processor = BlipProcessor.from_pretrained(\"Salesforce/blip-image-captioning-base\")\n",
"blip_model = BlipForConditionalGeneration.from_pretrained(\"Salesforce/blip-image-captioning-base\").to(device).eval()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "12d9402a-fdbe-4ade-99ed-26f5d7f9ccfd",
"metadata": {},
"outputs": [],
"source": [
"clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\").to(device).eval()\n",
"clip_processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d4f5e7b2-c873-4495-8ad1-9e32f4f1fbe1",
"metadata": {},
"outputs": [],
"source": [
"with open(\"../create_embeddings/wikiart_embeddings.json\", \"r\", encoding=\"utf-8\") as f:\n",
" data = json.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "87bc4121-f316-4769-bf5d-197db30fe2a3",
"metadata": {},
"outputs": [],
"source": [
"image_index = faiss.read_index(\"../create_index/image_index.faiss\")\n",
"text_index = faiss.read_index(\"../create_index/text_index.faiss\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "b41d1e5c-d606-4501-a22c-3cde576361d7",
"metadata": {},
"outputs": [],
"source": [
"def generate_caption(image: Image.Image):\n",
" inputs = blip_processor(image, return_tensors=\"pt\").to(device)\n",
" with torch.no_grad():\n",
" caption_ids = blip_model.generate(**inputs)\n",
" caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)\n",
" return caption"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "263c8672-f4b4-46b7-abf0-483ccfb83c86",
"metadata": {},
"outputs": [],
"source": [
"def get_clip_text_embedding(text):\n",
" inputs = clip_processor(text=[text], return_tensors=\"pt\", padding=True).to(device)\n",
" with torch.no_grad():\n",
" features = clip_model.get_text_features(**inputs)\n",
" features = features.cpu().numpy().astype(\"float32\")\n",
" faiss.normalize_L2(features)\n",
" return features"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "34827bd8-e0da-4252-b168-3c79f2d2fb02",
"metadata": {},
"outputs": [],
"source": [
"def get_clip_image_embedding(image):\n",
" inputs = clip_processor(images=image, return_tensors=\"pt\").to(device)\n",
" with torch.no_grad():\n",
" features = clip_model.get_image_features(**inputs)\n",
" features = features.cpu().numpy().astype(\"float32\")\n",
" faiss.normalize_L2(features)\n",
" return features"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "ec6399ac-a40d-49f7-9831-3085fca484c9",
"metadata": {},
"outputs": [],
"source": [
"def get_results_with_images(embedding, index, top_k=2):\n",
" D, I = index.search(embedding, top_k)\n",
" results = []\n",
" for idx in I[0]:\n",
" item = data[idx]\n",
" img_id = int(item[\"id\"])\n",
" try:\n",
" img = wikiart_dataset[img_id][\"image\"]\n",
" except IndexError:\n",
" continue\n",
" caption = f\"ID: {item['id']}\\n{item['caption']}\"\n",
" results.append((img, caption))\n",
" return results"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "76adeb1c-85d6-4e53-9c93-a312c21b71b8",
"metadata": {},
"outputs": [],
"source": [
"def search_similar_images(image: Image.Image):\n",
" caption = generate_caption(image)\n",
"\n",
" text_emb = get_clip_text_embedding(caption)\n",
" image_emb = get_clip_image_embedding(image)\n",
"\n",
" text_results = get_results_with_images(text_emb, text_index)\n",
" image_results = get_results_with_images(image_emb, image_index)\n",
"\n",
" return caption, text_results, image_results"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "da86df12-a996-4d1d-ae42-354984cf6dc2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"* Running on local URL: http://127.0.0.1:7862\n",
"* To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"text/html": [
"<div><iframe src=\"http://127.0.0.1:7862/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": []
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"demo = gr.Interface(\n",
" fn=search_similar_images,\n",
" inputs=gr.Image(label=\"Загрузите изображение\", type=\"pil\"),\n",
" outputs=[\n",
" gr.Textbox(label=\"📜 Сгенерированное описание\"),\n",
" gr.Gallery(label=\"🔍 Похожие по описанию (CLIP)\", height=\"auto\", columns=2),\n",
" gr.Gallery(label=\"🎨 Похожие по изображению (CLIP)\", height=\"auto\", columns=2)\n",
" ],\n",
" title=\"🎨 Semantic WikiArt Search (BLIP + CLIP)\",\n",
" description=\"Загрузите изображение. Модель BLIP сгенерирует описание, а CLIP найдёт похожие картины по тексту и изображению.\"\n",
")\n",
"\n",
"demo.launch()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "55fbac06-4781-4074-a1e6-26ff758bbfe0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Rerunning server... use `close()` to stop if you need to change `launch()` parameters.\n",
"----\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"* Running on public URL: https://ba46916423948a3a69.gradio.live\n",
"\n",
"This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
]
},
{
"data": {
"text/html": [
"<div><iframe src=\"https://ba46916423948a3a69.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": []
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"demo.launch(server_name=\"0.0.0.0\", server_port=7860, share=True)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c44447c3-0709-4419-a6a4-fc451f80702a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}