Spaces:
Paused
Paused
File size: 4,136 Bytes
cd6bb4b 5bf3195 d631df4 5bf3195 c8b8b02 5bf3195 75a550e 5bf3195 c8b8b02 5bf3195 75a550e 5bf3195 c8b8b02 045b802 d631df4 5bf3195 d631df4 5bf3195 d631df4 5bf3195 d631df4 5bf3195 d631df4 5bf3195 d631df4 5bf3195 d631df4 5bf3195 c8b8b02 ac8d16a d631df4 5bf3195 d631df4 5bf3195 d631df4 5bf3195 d631df4 5bf3195 cd6bb4b 5bf3195 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import pinecone
import streamlit as st
from transformers import pipeline, AutoTokenizer
from sentence_transformers import SentenceTransformer
# Pinecone credentials, read from Streamlit secrets (keys are issued at app.pinecone.io).
PINECONE_KEY, PINE_CONE_ENVIRONMENT = (
    st.secrets["PINECONE_API_KEY"],
    st.secrets["PINE_CONE_ENVIRONMENT"],
)
@st.cache_resource
def init_pinecone(index_name="company-description"):
    """Initialise the Pinecone client and return a handle to the company index.

    Wrapped in ``st.cache_resource`` so the connection is created once and
    reused across script reruns instead of on every interaction.

    Args:
        index_name: Pinecone index to open. Defaults to
            ``"company-description"``, the same index ``run_query`` reconnects
            to when a query fails (the previous hard-coded
            ``"dompany-description"`` was a typo).

    Returns:
        A ``pinecone.Index`` for ``index_name``.
    """
    # Credentials come from the module-level constants read from st.secrets.
    pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
    return pinecone.Index(index_name)
@st.cache_resource
def init_models():
    """Load the query-embedding model plus a QA pipeline and tokenizer.

    All three objects are built from the same Hugging Face model id. Cached
    with ``st.cache_resource`` so the models are loaded once, not per rerun.

    Returns:
        tuple: ``(retriever, reader, tokenizer)``. ``retriever`` is a
        ``SentenceTransformer`` used to embed queries, ``reader`` is a
        ``"question-answering"`` pipeline, and ``tokenizer`` is an
        ``AutoTokenizer``.
    """
    # Single source of truth for the model id so the three objects cannot drift apart.
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    retriever = SentenceTransformer(model_name)
    # NOTE(review): all-MiniLM-L6-v2 is a sentence-embedding checkpoint, not a
    # QA fine-tune, so this pipeline presumably gets an untrained QA head. No
    # live code in this file calls `reader` or `tokenizer` — confirm whether
    # they can be dropped.
    reader = pipeline(tokenizer=model_name, model=model_name, task="question-answering")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return retriever, reader, tokenizer
# Keep any Pinecone handle already stored for this session (for example one
# that run_query re-created after a failed query). Do not reset it to the
# cached handle on every Streamlit rerun.
if "index" not in st.session_state:
    st.session_state.index = init_pinecone()
retriever, reader, tokenizer = init_models()
def card(name, description, score):
    """Render one search hit as a small HTML card through ``st.markdown``.

    ``name``, ``description`` and ``score`` are inserted into the markup as-is
    (the description may already carry HTML such as ``<mark>`` tags), and the
    markdown is rendered with ``unsafe_allow_html=True``.

    Returns:
        Whatever ``st.markdown`` returns.
    """
    lines = [
        '<div class="container-fluid">',
        '<div class="row align-items-start">',
        '<div class="col-md-12 col-sm-12">',
        f"<b>{name}</b>",
        "<br>",
        '<span style="color: #808080;">',
        f"<small>{description}</small>",
        f"[<b>Score: </b>{score}]",
        "</span>",
        "</div>",
        "</div>",
        "</div>",
    ]
    # Leading and trailing newline mirror the original triple-quoted literal.
    markup = "\n" + "\n".join(lines) + "\n"
    return st.markdown(markup, unsafe_allow_html=True)
# Renders an empty page title. NOTE(review): looks like a placeholder — confirm the intended heading.
st.title(body="")
def _company_display_name(raw_name):
    """Return *raw_name* with one trailing ``"_description"`` removed, if present.

    Unlike ``str.strip("_description")``, which treats its argument as a set
    of characters and trims any of them from both ends (turning
    ``"acme_description"`` into ``"acm"``), this removes only the literal suffix.
    """
    suffix = "_description"
    if raw_name.endswith(suffix):
        return raw_name[: -len(suffix)]
    return raw_name


def _highlight_name(description, name):
    """Return *description* HTML-escaped, with each occurrence of *name* wrapped in ``<mark>``.

    Each text fragment is escaped separately, so the only raw HTML in the result
    is the ``<mark>`` tags added here.
    """
    import html

    if not name:
        # Nothing to highlight. Also, str.split("") would raise ValueError.
        return html.escape(description)
    marked = f"<mark>{html.escape(name)}</mark>"
    return marked.join(html.escape(part) for part in description.split(name))


def run_query(query, top_k=3):
    """Search the company index for *query* and render one card per match.

    The query is embedded with the module-level ``retriever`` and looked up in
    the Pinecone index held in ``st.session_state.index``. Each match is then
    rendered with ``card`` in descending score order. If the lookup raises, the
    client is re-initialised, a fresh index handle is stored in session state,
    and the lookup is retried once. A second failure propagates.

    Args:
        query: Free-text search string.
        top_k: Number of matches to request. The same value is used for the
            first attempt and the retry, so the number of results does not
            depend on which path succeeded.
    """
    import html

    xq = retriever.encode([query]).tolist()
    # Only scores and metadata are read below, so vector values are not requested.
    try:
        xc = st.session_state.index.query(xq, top_k=top_k, include_metadata=True)
    except Exception:
        # Force a reconnect and retry once. Catch Exception (not a bare except)
        # so KeyboardInterrupt / SystemExit are not swallowed.
        pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
        st.session_state.index = pinecone.Index("company-description")
        xc = st.session_state.index.query(xq, top_k=top_k, include_metadata=True)

    results = [
        {
            "score": match["score"],
            "name": _company_display_name(match["metadata"]["company_name"]),
            "description": match["metadata"]["description"],
        }
        for match in xc["matches"]
    ]
    for result in sorted(results, key=lambda r: r["score"], reverse=True):
        name = result["name"]
        # card() interpolates its arguments into unsafe_allow_html markup, so
        # metadata text is escaped here before any tags are added.
        card(
            html.escape(name),
            _highlight_name(result["description"], name),
            round(result["score"], 4),
        )
def check_password():
    """Return ``True`` if the user has entered the correct password.

    Until the typed value matches ``st.secrets["password"]``, a password field
    is shown. After a wrong attempt, an error message is shown below it. The
    outcome is stored in ``st.session_state["password_correct"]``.
    """
    import hmac

    def password_entered():
        """Compare the submitted password with the secret and record the result."""
        entered = st.session_state["password"]
        expected = str(st.secrets["password"])
        # Constant-time comparison of UTF-8 bytes. This avoids leaking via
        # timing how much of the password matched. Unlike compare_digest on
        # str, it does not raise TypeError for non-ASCII input.
        if hmac.compare_digest(entered.encode("utf-8"), expected.encode("utf-8")):
            st.session_state["password_correct"] = True
            del st.session_state["password"]  # don't store password
        else:
            st.session_state["password_correct"] = False

    status = st.session_state.get("password_correct")
    if status:
        # Password correct.
        return True
    # First run (key absent -> None) or failed attempt (False): show the input.
    st.text_input(
        "Password", type="password", on_change=password_entered, key="password"
    )
    if status is not None:
        st.error("😕 Password incorrect")
    return False
if check_password():
    st.write("\nSearch for a company in free text\n")
    # Bootstrap stylesheet used by the card() markup. The href had been mangled
    # to "[email protected]", so the CSS never loaded.
    # NOTE(review): version pinned to 4.0.0 to match the SRI integrity hash —
    # update both together if upgrading.
    st.markdown(
        "\n"
        '<link rel="stylesheet" '
        'href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" '
        'integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" '
        'crossorigin="anonymous">'
        "\n",
        unsafe_allow_html=True,
    )
    query = st.text_input("Search!", "")
    # Skip empty or whitespace-only input: there is nothing to search for.
    if query.strip():
        run_query(query)
|