# semsearch / app.py
# Author: hanoch.rahimi@gmail
# Last commit a280e4d: added OpenAI summarization and visual design
import json
from langchain.chains import RetrievalQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Pinecone
import openai
import pinecone
import streamlit as st
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from utils import get_companies_data
# --- Credentials, all read from Streamlit secrets -------------------------
# Pinecone API key and environment (both from app.pinecone.io).
PINECONE_KEY = st.secrets["PINECONE_API_KEY"]
# OpenAI API key, handed to the embeddings wrapper below.
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
PINE_CONE_ENVIRONMENT = st.secrets["PINE_CONE_ENVIRONMENT"]

# --- OpenAI embeddings wrapper ---------------------------------------------
model_name = 'text-embedding-ada-002'
embed = OpenAIEmbeddings(model=model_name, openai_api_key=OPENAI_API_KEY)

# Use the full browser width for the page layout.
st.set_page_config(layout="wide")
@st.cache_resource
def init_pinecone():
    """Initialise the Pinecone client and return a handle to the company index.

    Wrapped in ``st.cache_resource`` so the handle is created once and then
    reused instead of being rebuilt on every call.

    Returns:
        The ``pinecone.Index`` handle for the ``company-description`` index.
    """
    pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)  # get a free api key from app.pinecone.io
    # Use the same index name as the reload path in run_query(); the earlier
    # "dompany-description" spelling did not match it.
    return pinecone.Index("company-description")


st.session_state.index = init_pinecone()
@st.cache_resource
def init_models():
    """Load the sentence encoder and the tokenizer of the same checkpoint.

    Wrapped in ``st.cache_resource`` so both objects are built once and
    then reused.

    Returns:
        A ``(retriever, tokenizer)`` pair: the ``SentenceTransformer`` used
        to encode queries and the ``AutoTokenizer`` for the same checkpoint.
    """
    checkpoint = "sentence-transformers/all-MiniLM-L6-v2"
    encoder = SentenceTransformer(checkpoint)
    tok = AutoTokenizer.from_pretrained(checkpoint)
    return encoder, tok


retriever, tokenizer = init_models()
def card(name, description, score, data_type, region, country):
    """Render one search hit as a Bootstrap grid row.

    Args:
        name: Company name, shown in bold.
        description: Company description shown next to the name.
        score: Relevance score shown in the last column.
        data_type: Record type label shown in the last column.
        region: Region label.
        country: Country label.

    Returns:
        Whatever ``st.markdown`` returns for the rendered element.
    """
    from html import escape  # stdlib; imported locally so the block is self-contained

    # The markup is emitted with unsafe_allow_html=True, so every
    # interpolated value is HTML-escaped first. The values come from index
    # metadata (presumably including scraped web content - confirm upstream).
    name, description, score, data_type, region, country = (
        escape(str(value))
        for value in (name, description, score, data_type, region, country)
    )
    return st.markdown(f"""
<div class="container-fluid">
<div class="row align-items-start" style="padding-bottom:10px;">
<div class="col-md-8 col-sm-8">
<b>{name}.</b>
<span style="">
{description}
</span>
</div>
<div class="col-md-1 col-sm-1">
<span>{region}</span>
</div>
<div class="col-md-1 col-sm-1">
<span>{country}</span>
</div>
<div class="col-md-1 col-sm-1">
<span>{data_type}</span>
<span>[Score: {score}]</span>
</div>
</div>
</div>
""", unsafe_allow_html=True)
def index_query(xq, top_k, regions=None, countries=None):
    """Query the Pinecone index held in ``st.session_state.index``.

    Args:
        xq: Encoded query passed straight through to ``Index.query``.
        top_k: Number of matches to request; values below 1 are raised to 1.
        regions: Optional collection of region names; when non-empty, matches
            are restricted to records whose ``region`` is one of them.
        countries: Optional collection of country names; when non-empty,
            matches are restricted to records whose ``country`` is one of them.

    Returns:
        The response object returned by ``Index.query``.
    """
    clauses = []
    if regions:
        clauses.append({'region': {"$in": regions}})
    if countries:
        clauses.append({'country': {"$in": countries}})

    if len(clauses) > 1:
        metadata_filter = {"$and": clauses}
    elif clauses:
        metadata_filter = clauses[0]
    else:
        metadata_filter = {}

    # Honour the caller's top_k (it used to be hard-coded to 20, which made
    # the "Top K Results" control a no-op).
    # NOTE(review): Pinecone documents `include_values`; confirm that
    # `include_vectors` is accepted by the installed client version.
    return st.session_state.index.query(
        xq,
        top_k=max(1, int(top_k)),
        filter=metadata_filter,
        include_metadata=True,
        include_vectors=True,
    )
def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=1024):
    """Send ``prompt`` to the OpenAI chat completion API and return the reply.

    Args:
        prompt: Text sent as the user message.
        engine: Chat model name.
        temp: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        max_tokens: Upper bound on the length of the generated reply.

    Returns:
        The stripped text of the first choice, or the string
        ``"Failed to generate a response."`` when the API raises an
        ``OpenAIError``.
    """
    try:
        response = openai.ChatCompletion.create(
            model=engine,
            messages=[
                {"role": "system", "content": "You are an assistant analyzing startup companies for investments."},
                {"role": "user", "content": prompt},
            ],
            temperature=temp,
            # Previously accepted but silently dropped; forward it so the
            # parameter actually takes effect.
            top_p=top_p,
            max_tokens=max_tokens,
        )
        print(response)
        return response.choices[0].message["content"].strip()
    except openai.error.OpenAIError as e:
        print(f"An error occurred: {str(e)}")
        return "Failed to generate a response."
def _match_to_result(match, scrape_boost):
    """Convert one index match into the result dict used for ranking/display."""
    metadata = match["metadata"]

    score = match['score']
    if 'type' in metadata and metadata['type'] == 'description-webcontent':
        score = score * scrape_boost

    # Drop the record-type suffix from the vector id, if it has one.
    record_id = match['id']
    for suffix in ("_description", "_webcontent"):
        if record_id.endswith(suffix):
            record_id = record_id[:-len(suffix)]
            break

    # Key order is kept stable because the dict is stringified into the
    # summarisation prompt.
    return {
        'score': score,
        'id': record_id,
        'name': metadata['company_name'],
        'description': metadata['description'] if "description" in metadata else "",
        'metadata': metadata,
    }


def run_query(query, prompt, scrape_boost, top_k, regions, countries):
    """Run a search, write an OpenAI summary, and list the matching companies.

    Args:
        query: Free-text search query.
        prompt: Instruction text prepended to the summarisation prompt.
        scrape_boost: Factor applied to the score of matches whose metadata
            ``type`` is ``'description-webcontent'``.
        top_k: Number of matches requested from the index.
        regions: Region names forwarded to ``index_query``.
        countries: Country names forwarded to ``index_query``.
    """
    xq = retriever.encode([query]).tolist()
    try:
        xc = index_query(xq, top_k, regions, countries)
    except Exception:
        # Force a reload of the index handle and retry once; an error from
        # the retry propagates. (Was a bare `except:`, which also caught
        # SystemExit and KeyboardInterrupt.)
        pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
        st.session_state.index = pinecone.Index("company-description")
        xc = index_query(xq, top_k, regions, countries)

    results = [_match_to_result(match, scrape_boost) for match in xc['matches']]

    # Rank by boosted score BEFORE cutting the top 10 for the summary, so
    # the summary covers the same companies that head the list below.
    sorted_result = sorted(results, key=lambda x: x['score'], reverse=True)

    # Summarize the results
    prompt_txt = prompt + """
Company descriptions: {descriptions}
User query: {query}
"""
    prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
    full_prompt = prompt_template.format(descriptions=sorted_result[:10], query=query)
    m_text = call_openai(full_prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=1024)
    # Explicit write instead of a bare `m_text` expression relying on
    # Streamlit magic.
    st.write(m_text)

    st.markdown("<h2>Related companies</h2>", unsafe_allow_html=True)
    for r in sorted_result:
        metadata = r["metadata"]
        card(
            r["name"],
            r["description"],
            round(r["score"], 4),
            metadata["type"] if "type" in metadata else "",
            # Tolerate records without region/country, as is already done
            # for type and description.
            metadata["region"] if "region" in metadata else "",
            metadata["country"] if "country" in metadata else "",
        )
def check_password():
    """Returns `True` if the user had the correct password.

    Shows a password input until the value entered matches
    ``st.secrets["password"]``; after a failed attempt an error message is
    shown below the input.
    """
    import hmac  # stdlib; imported locally so the block is self-contained

    def password_entered():
        """Checks whether a password entered by the user is correct."""
        # Constant-time comparison so the check does not leak how many
        # leading characters matched. Both sides are converted to UTF-8
        # bytes because compare_digest rejects non-ASCII str arguments.
        entered = str(st.session_state["password"]).encode("utf-8")
        expected = str(st.secrets["password"]).encode("utf-8")
        if hmac.compare_digest(entered, expected):
            st.session_state["password_correct"] = True
            del st.session_state["password"]  # don't store password
        else:
            st.session_state["password_correct"] = False

    first_run = "password_correct" not in st.session_state
    if not first_run and st.session_state["password_correct"]:
        # Password correct.
        return True

    # First run or wrong password: show the input again.
    st.text_input(
        "Password", type="password", on_change=password_entered, key="password"
    )
    if not first_run:
        st.error("😕 Password incorrect")
    return False
# Main page: only rendered once the password gate has been passed.
if check_password():
    st.title("")
    st.write("""
Search for a company in free text
""")

    # Bootstrap 4.0.0 stylesheet used by the result cards. The package
    # specifier had been mangled to "[email protected]" by e-mail obfuscation;
    # the integrity hash is the one published for bootstrap@4.0.0.
    st.markdown("""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
""", unsafe_allow_html=True)

    # Explicit encoding so the country names load the same way regardless
    # of the platform's default locale.
    with open("data/countries.json", "r", encoding="utf-8") as f:
        countries = json.load(f)['countries']

    # Sidebar controls: metadata filters and ranking parameters.
    countries_selectbox = st.sidebar.multiselect("Country", countries, default=[])
    all_regions = ('Africa', 'Europe', 'Asia & Pacific', 'North America', 'South/Latin America')
    region_selectbox = st.sidebar.multiselect("Region", all_regions, default=all_regions)
    scrape_boost = st.sidebar.number_input('webcontent_boost', value=2.)
    top_k = st.sidebar.number_input('Top K Results', value=20)

    # Editable instruction text that run_query() prepends to the
    # summarisation prompt.
    default_prompt = """
summarize the outcome of this search. The context is a list of company names followed by the company's description and a relevance score to the user query.
the report should mention the most important companies and how they compare to each other and contain the following sections:
1) Title: query text (summarized if more than 20 tokens)
2) Best matches: Naming of the 3 companies from the list that are most similar to the search query:
- summarize what they are doing
- name customers and technology if they are mentioned
- compare them to each other and point out what they do differently or what is their unique selling proposition
----"""
    prompt = st.text_area("Enter prompt", value=default_prompt)
    query = st.text_input("Search!", "")

    # Only search once the user has typed something.
    if query != "":
        run_query(query, prompt, scrape_boost, top_k, region_selectbox, countries_selectbox)