Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import pinecone
|
| 3 |
+
from sentence_transformers import SentenceTransformer
|
| 4 |
+
import torch
|
| 5 |
+
from splade.models.transformer_rep import Splade
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
+
from datasets import load_dataset
|
| 8 |
+
|
| 9 |
+
pinecone.init(
|
| 10 |
+
api_key='884344f6-d820-4bc8-9edf-4157373df452',
|
| 11 |
+
environment='gcp-starter'
|
| 12 |
+
)
|
| 13 |
+
index = pinecone.Index('pubmed-splade')
|
| 14 |
+
|
| 15 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 16 |
+
# check device being run on
|
| 17 |
+
if device != 'cuda':
|
| 18 |
+
print("==========\n"+
|
| 19 |
+
"WARNING: You are not running on GPU so this may be slow.\n"+
|
| 20 |
+
"\n==========")
|
| 21 |
+
|
| 22 |
+
dense_model = SentenceTransformer(
|
| 23 |
+
'msmarco-bert-base-dot-v5',
|
| 24 |
+
device=device
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
sparse_model_id = 'naver/splade-cocondenser-ensembledistil'
|
| 28 |
+
sparse_model = Splade(sparse_model_id, agg='max')
|
| 29 |
+
sparse_model.to(device) # move to GPU if possible
|
| 30 |
+
sparse_model.eval()
|
| 31 |
+
|
| 32 |
+
tokenizer = AutoTokenizer.from_pretrained(sparse_model_id)
|
| 33 |
+
data = load_dataset('Binaryy/cream_listings', split='train')
|
| 34 |
+
df = data.to_pandas()
|
| 35 |
+
|
| 36 |
+
def encode(text: str):
|
| 37 |
+
# create dense vec
|
| 38 |
+
dense_vec = dense_model.encode(text).tolist()
|
| 39 |
+
# create sparse vec
|
| 40 |
+
input_ids = tokenizer(text, return_tensors='pt')
|
| 41 |
+
with torch.no_grad():
|
| 42 |
+
sparse_vec = sparse_model(
|
| 43 |
+
d_kwargs=input_ids.to(device)
|
| 44 |
+
)['d_rep'].squeeze()
|
| 45 |
+
# convert to dictionary format
|
| 46 |
+
indices = sparse_vec.nonzero().squeeze().cpu().tolist()
|
| 47 |
+
values = sparse_vec[indices].cpu().tolist()
|
| 48 |
+
sparse_dict = {"indices": indices, "values": values}
|
| 49 |
+
# return vecs
|
| 50 |
+
return dense_vec, sparse_dict
|
| 51 |
+
|
| 52 |
+
def search(query):
|
| 53 |
+
dense, sparse = encode(query)
|
| 54 |
+
# query
|
| 55 |
+
xc = index.query(
|
| 56 |
+
vector=dense,
|
| 57 |
+
sparse_vector=sparse,
|
| 58 |
+
top_k=5, # how many results to return
|
| 59 |
+
include_metadata=True
|
| 60 |
+
)
|
| 61 |
+
match_ids = [match['id'].split('-')[0] for match in xc['matches']]
|
| 62 |
+
# Query the existing DataFrame based on 'id'
|
| 63 |
+
filtered_df = df[df['_id'].isin(match_ids)]
|
| 64 |
+
attributes_to_extract = ['_id', 'postedBy.accountName', 'images', 'title', 'location', 'price']
|
| 65 |
+
extracted_data = filtered_df[attributes_to_extract]
|
| 66 |
+
result_json = extracted_data.to_json(orient='records')
|
| 67 |
+
return result_json
|
| 68 |
+
|
| 69 |
+
# Create a Gradio interface
|
| 70 |
+
iface = gr.Interface(
|
| 71 |
+
fn=search,
|
| 72 |
+
inputs="text",
|
| 73 |
+
outputs="json",
|
| 74 |
+
title="Semantic Search Prototype",
|
| 75 |
+
description="Enter your query to perform a semantic search.",
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# Launch the Gradio interface
|
| 79 |
+
iface.launch(share=True)
|