Spaces:
Paused
Paused
import pinecone | |
import streamlit as st | |
from transformers import pipeline, AutoTokenizer | |
from sentence_transformers import SentenceTransformer | |
# Pinecone credentials, read from the Streamlit secrets store.
# The original notes point to app.pinecone.io as the place to obtain them.
PINECONE_KEY = st.secrets["PINECONE_API_KEY"]
PINE_CONE_ENVIRONMENT = st.secrets["PINE_CONE_ENVIRONMENT"]
def init_pinecone():
    """Initialise the Pinecone client and return the company index handle.

    Uses the module-level ``PINECONE_KEY`` and ``PINE_CONE_ENVIRONMENT``
    (a free API key is available from app.pinecone.io).

    Returns:
        The object returned by ``pinecone.Index("company-description")``.
    """
    pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
    # The name must agree with the index that run_query() reconnects to
    # after a failed request; it was previously misspelled "dompany-...".
    return pinecone.Index("company-description")
def init_models():
    """Load the sentence encoder, the question-answering pipeline and tokenizer.

    All three are built from the same checkpoint name.

    Returns:
        A ``(retriever, reader, tokenizer)`` tuple.
    """
    # Alternative retriever checkpoint (disabled): multi-qa-MiniLM-L6-cos-v1
    checkpoint = "sentence-transformers/all-MiniLM-L6-v2"
    return (
        SentenceTransformer(checkpoint),
        pipeline(task='question-answering', model=checkpoint, tokenizer=checkpoint),
        AutoTokenizer.from_pretrained(checkpoint),
    )
# Streamlit re-executes this script on every user interaction, so keep the
# Pinecone handle and the (slow to load) models in session state and only
# build them when the session does not have them yet.
if "index" not in st.session_state:
    st.session_state.index = init_pinecone()
if "models" not in st.session_state:
    st.session_state.models = init_models()
# Module-level names used by the rest of the script.
retriever, reader, tokenizer = st.session_state.models
def card(name, description, score):
    """Render one search hit as a small HTML card.

    Args:
        name: Text shown in bold at the top of the card.
        description: Text shown below the name; it is inserted as raw HTML
            (run_query passes it with ``<mark>`` tags already embedded).
        score: Value shown after the "Score:" label.

    Returns:
        Whatever ``st.markdown`` returns for the rendered element.
    """
    # NOTE(review): all three values are interpolated without escaping and
    # rendered with unsafe_allow_html=True, so they must come from a trusted
    # source. Escaping here would also break the <mark> highlighting.
    return st.markdown(f"""
<div class="container-fluid">
<div class="row align-items-start">
<div class="col-md-12 col-sm-12">
<b>{name}</b>
<br>
<span style="color: #808080;">
<small>{description}</small>
[<b>Score: </b>{score}]
</span>
</div>
</div>
</div>
""", unsafe_allow_html=True)
# Page title call; the title text is currently an empty string.
st.title("")
def run_query(query, top_k=3):
    """Embed *query*, search the Pinecone index and render each hit as a card.

    Args:
        query: Free-text search string entered by the user.
        top_k: Number of matches to request. The same value is used for the
            first attempt and for the retry after a reconnect.

    Returns:
        None. Results are drawn on the page via ``card``.
    """
    xq = retriever.encode([query]).tolist()
    try:
        xc = st.session_state.index.query(
            xq, top_k=top_k, include_metadata=True, include_vectors=True
        )
    except Exception:
        # The request failed: re-initialise the client, replace the stored
        # index handle and retry once. A second failure propagates.
        pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
        st.session_state.index = pinecone.Index("company-description")
        xc = st.session_state.index.query(
            xq, top_k=top_k, include_metadata=True, include_vectors=True
        )

    suffix = "_description"
    results = []
    for match in xc['matches']:
        metadata = match["metadata"]
        name = metadata['company_name']
        # Remove the literal suffix only. str.strip() would treat the
        # argument as a character set and also eat letters of the name.
        if name.endswith(suffix):
            name = name[:-len(suffix)]
        results.append({
            'score': match['score'],
            "name": name,
            "description": metadata['description'],
        })

    for hit in sorted(results, key=lambda r: r['score'], reverse=True):
        company_name = hit["name"]
        description = hit["description"]
        # replace("", ...) would insert the tag between every character,
        # so only highlight when there is a name to look for.
        if company_name:
            description = description.replace(
                company_name, f"<mark>{company_name}</mark>"
            )
        card(company_name, description, round(hit["score"], 4))
def check_password():
    """Return True if the user has entered the correct password.

    Until then a password input is shown (plus an error message after a
    wrong attempt) and False is returned. The outcome of the last attempt is
    kept in ``st.session_state["password_correct"]``.
    """
    import hmac

    def password_entered():
        """Compare the submitted password with the configured secret."""
        entered = str(st.session_state["password"]).encode("utf-8")
        expected = str(st.secrets["password"]).encode("utf-8")
        # Constant-time comparison; bytes are used because compare_digest
        # rejects non-ASCII str arguments.
        if hmac.compare_digest(entered, expected):
            st.session_state["password_correct"] = True
            del st.session_state["password"]  # don't store password
        else:
            st.session_state["password_correct"] = False

    if st.session_state.get("password_correct", False):
        # Password correct.
        return True

    # First run or wrong password: show the input again.
    st.text_input(
        "Password", type="password", on_change=password_entered, key="password"
    )
    if "password_correct" in st.session_state:
        # An attempt was made and it failed.
        st.error("😕 Password incorrect")
    return False
# Main page: only shown once the password check has passed.
if check_password():
    st.write("""
Search for a company in free text
""")
    # Load Bootstrap 4.0.0 for the classes used by card(); the version in
    # the URL must match the SRI integrity hash.
    st.markdown("""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
""", unsafe_allow_html=True)
    query = st.text_input("Search!", "")
    # Skip the search while the box is empty.
    if query:
        run_query(query)