LeBuH committed
Commit 6283c79 · verified · 1 Parent(s): c6ad3ab

Upload app.py

Files changed (1):
  app.py +1 -8
app.py CHANGED
@@ -12,7 +12,6 @@ from transformers import (
 )
 from datasets import load_dataset

-# Load the dataset and models
 wikiart_dataset = load_dataset("huggan/wikiart", split="train")
 device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

@@ -22,11 +21,9 @@ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image
 clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
 clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

-# Load the FAISS indexes
 image_index = faiss.read_index("image_index.faiss")
 text_index = faiss.read_index("text_index.faiss")

-# Generate a caption with BLIP
 def generate_caption(image: Image.Image):
     inputs = blip_processor(image, return_tensors="pt").to(device)
     with torch.no_grad():
@@ -34,7 +31,6 @@ def generate_caption(image: Image.Image):
     caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)
     return caption

-# Get the CLIP embedding of a text
 def get_clip_text_embedding(text):
     inputs = clip_processor(text=[text], return_tensors="pt", padding=True).to(device)
     with torch.no_grad():
@@ -43,7 +39,6 @@ def get_clip_text_embedding(text):
     faiss.normalize_L2(features)
     return features

-# Get the CLIP embedding of an image
 def get_clip_image_embedding(image):
     inputs = clip_processor(images=image, return_tensors="pt").to(device)
     with torch.no_grad():
@@ -52,13 +47,12 @@ def get_clip_image_embedding(image):
     faiss.normalize_L2(features)
     return features

-# Retrieve similar images for an embedding
 def get_results_with_images(embedding, index, top_k=2):
     D, I = index.search(embedding, top_k)
     results = []
     for idx in I[0]:
         try:
-            item = wikiart_dataset[idx]
+            item = wikiart_dataset[int(idx)]
             img = item["image"]
             title = item.get("title", "Untitled")
             artist = item.get("artist", "Unknown")
@@ -79,7 +73,6 @@ def search_similar_images(image: Image.Image):

     return caption, text_results, image_results

-# Gradio interface
 demo = gr.Interface(
     fn=search_similar_images,
     inputs=gr.Image(label="Upload an image", type="pil"),
 
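The one substantive change in this commit is the `int(idx)` cast in `get_results_with_images`: FAISS's `index.search` returns neighbor ids as a NumPy `int64` array, and some versions of `datasets` reject NumPy integer keys in `Dataset.__getitem__` with a "wrong key type" error. A minimal sketch of the failure mode and the fix, using a hypothetical two-row dataset in place of `wikiart_dataset`:

```python
import numpy as np
from datasets import Dataset

# Hypothetical stand-in for wikiart_dataset.
ds = Dataset.from_dict({"title": ["Starry Night", "Water Lilies"]})

faiss_id = np.int64(1)    # the dtype index.search() hands back
# ds[faiss_id]            # may raise TypeError on some datasets releases
item = ds[int(faiss_id)]  # a plain Python int is accepted everywhere
print(item["title"])      # -> Water Lilies
```

The cast costs nothing and makes the row lookup independent of which `datasets` version the app happens to run on.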