Spaces:
Runtime error
Runtime error
import io
import pickle

import clip
import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
# Set to True to run CLIP on the default CUDA device instead of the CPU.
is_gpu = False
# torch/clip accept a device string; the original `CUDA(0)` was an undefined
# name and would raise NameError as soon as is_gpu was switched on.
device = "cuda" if is_gpu else "cpu"

from datasets import load_dataset

# NOTE(review): `dataset` is not referenced anywhere else in this script;
# loading it only costs startup time/bandwidth. Confirm it is needed.
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")

# Precomputed search index, unpickled as a 3-tuple:
#   id2url    - mapping photo id -> image URL (used via .get() in search)
#   img_names - photo file names, indexed by an array of row positions in search
#   img_emb   - image embedding matrix, one row per photo (presumably
#               L2-normalized CLIP ViT-B/32 features, shape (N, D) — TODO confirm)
# NOTE(review): pickle.load can execute arbitrary code; only ever load a
# trusted, locally produced index file here.
emb_filename = 'unsplash-25k-photos-embeddings-indexes.pkl'
with open(emb_filename, 'rb') as emb:
    id2url, img_names, img_emb = pickle.load(emb)

# jit=False loads the regular (non-TorchScript) model on `device`.
orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)
def search(search_query, max_results=5, max_candidates=15):
    """Return up to ``max_results`` Unsplash images that best match a text query.

    The query is embedded with CLIP and ranked by dot product against the
    precomputed image embeddings ``img_emb``. The ``max_candidates`` best
    matches are then downloaded in order, skipping photos without a known URL
    and photos that fail to download or decode, until ``max_results`` images
    have been collected.

    Args:
        search_query: Free-text description of the desired image.
        max_results: Maximum number of images to return.
        max_candidates: How many top-ranked photos to try before giving up.

    Returns:
        A list of ``PIL.Image.Image`` objects, best match first. It may be
        shorter than ``max_results``, or empty.
    """
    with torch.no_grad():
        # truncate=True keeps long prompts from raising instead of searching.
        # The tokens must live on the same device as the model.
        tokens = clip.tokenize([search_query], truncate=True).to(device)
        text_encoded = orig_clip_model.encode_text(tokens)
        # L2-normalize so the dot product below is a cosine similarity
        # (assumes img_emb rows are normalized too — TODO confirm).
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
    text_features = text_encoded.cpu().numpy()

    # One similarity score per photo; rank best first and keep a few spare
    # candidates in case some of them cannot be fetched.
    similarities = (text_features @ img_emb.T).squeeze(0)
    best_photos = np.argsort(similarities)[::-1][:max_candidates]
    best_photo_ids = img_names[best_photos]

    imgs = []
    for photo_name in best_photo_ids:
        # Strip only the final extension; the old split('.') raised on names
        # with zero or several dots.
        photo_id = photo_name.rsplit(".", 1)[0]
        url = id2url.get(photo_id, "")
        if not url:
            continue
        try:
            response = requests.get(url + "?w=512", timeout=10)
            response.raise_for_status()
            # Decode from the already-decompressed body and force the decode
            # now, so a bad image is skipped here rather than failing later.
            img = Image.open(io.BytesIO(response.content))
            img.load()
        except (requests.RequestException, OSError):
            # Best effort: one unreachable or corrupt photo must not abort the search.
            continue
        imgs.append(img)
        if len(imgs) >= max_results:
            break
    return imgs
# Gradio UI: a one-line prompt box with a search button above a gallery that
# shows the images returned by `search`.
# Layout options are passed to the constructors because Component.style()
# was removed in Gradio 4.
with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            text = gr.Textbox(
                label="Enter your prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            # scale=0 keeps the button at its natural width (old full_width=False).
            search_btn = gr.Button("Search for images", scale=0)
        gallery = gr.Gallery(
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            columns=5,
            height="auto",
        )
    search_btn.click(search, text, gallery)
    # Pressing Enter in the prompt box runs the same search.
    text.submit(search, text, gallery)

demo.launch()