LeBuH commited on
Commit
3bda56e
·
verified ·
1 Parent(s): ed1bc19

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -13,6 +13,13 @@ from transformers import (
13
  from datasets import load_dataset
14
 
15
  wikiart_dataset = load_dataset("huggan/wikiart", split="train", streaming=True)
 
 
 
 
 
 
 
16
  device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
17
 
18
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
@@ -52,7 +59,7 @@ def get_results_with_images(embedding, index, top_k=2):
52
  results = []
53
  for idx in I[0]:
54
  try:
55
- item = wikiart_dataset[int(idx)]
56
  img = item["image"]
57
  title = item.get("title", "Untitled")
58
  artist = item.get("artist", "Unknown")
 
13
  from datasets import load_dataset
14
 
15
  wikiart_dataset = load_dataset("huggan/wikiart", split="train", streaming=True)
16
+
17
+ def get_item_streaming(dataset, idx):
18
+ for i, item in enumerate(dataset):
19
+ if i == idx:
20
+ return item
21
+ raise IndexError("Index out of range")
22
+
23
  device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
24
 
25
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
 
59
  results = []
60
  for idx in I[0]:
61
  try:
62
+ item = get_item_streaming(wikiart_dataset, int(idx))
63
  img = item["image"]
64
  title = item.get("title", "Untitled")
65
  artist = item.get("artist", "Unknown")