import json

from langchain.chains import RetrievalQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Pinecone
import openai
import pinecone
import streamlit as st
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer

from utils import get_companies_data

PINECONE_KEY = st.secrets["PINECONE_API_KEY"]  # app.pinecone.io
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]  # platform.openai.com
PINE_CONE_ENVIRONMENT = st.secrets["PINE_CONE_ENVIRONMENT"]  # app.pinecone.io

openai.api_key = OPENAI_API_KEY  # needed for the openai.ChatCompletion calls below

model_name = 'text-embedding-ada-002'
embed = OpenAIEmbeddings(
    model=model_name,
    openai_api_key=OPENAI_API_KEY
)

st.set_page_config(layout="wide")
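
# Connect to the Pinecone index holding the company-description embeddings.
# The index handle is kept in st.session_state so it survives Streamlit reruns.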
def init_pinecone():
    pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)  # get a free api key from app.pinecone.io
    return pinecone.Index("company-description")

st.session_state.index = init_pinecone()
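
# Load the query-side models: a SentenceTransformer that embeds search queries,
# plus its tokenizer (only needed by the commented-out question-answering reader).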
def init_models():
    # retriever = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    retriever = SentenceTransformer(model_name)
    # reader = pipeline(tokenizer=model_name, model=model_name, task='question-answering')
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # vectorstore = Pinecone(st.session_state.index, embed.embed_query, text_field)
    return retriever, tokenizer  # , vectorstore

retriever, tokenizer = init_models()
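
# Render one search result as a Bootstrap grid row: company name and description
# on the left, then region, country, data type and the relevance score.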
def card(name, description, score, data_type, region, country):
    return st.markdown(f"""
    <div class="container-fluid">
        <div class="row align-items-start" style="padding-bottom:10px;">
            <div class="col-md-8 col-sm-8">
                <b>{name}.</b>
                <span style="">
                    {description}
                </span>
            </div>
            <div class="col-md-1 col-sm-1">
                <span>{region}</span>
            </div>
            <div class="col-md-1 col-sm-1">
                <span>{country}</span>
            </div>
            <div class="col-md-1 col-sm-1">
                <span>{data_type}</span>
                <span>[Score: {score}]</span>
            </div>
        </div>
    </div>
    """, unsafe_allow_html=True)
def index_query(xq, top_k, regions=[], countries=[]):
    # st.write(f"Regions: {regions}")
    filters = []
    if len(regions) > 0:
        filters.append({'region': {"$in": regions}})
    if len(countries) > 0:
        filters.append({'country': {"$in": countries}})
    if len(filters) == 1:
        filter = filters[0]
    elif len(filters) > 1:
        filter = {"$and": filters}
    else:
        filter = {}
    # st.write(filter)
    # only metadata and scores are used downstream, so the vectors themselves are not requested
    xc = st.session_state.index.query(xq, top_k=top_k, filter=filter, include_metadata=True)
    # xc = st.session_state.index.query(xq, top_k=top_k, include_metadata=True)
    return xc
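
# Ask the chat model to summarize the retrieved companies. Uses the legacy
# openai-python 0.x ChatCompletion interface, matching the rest of this script.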
def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=1024):
    try:
        response = openai.ChatCompletion.create(
            model=engine,
            messages=[
                {"role": "system", "content": "You are an assistant analyzing startup companies for investments."},
                {"role": "user", "content": prompt}
            ],
            temperature=temp,
            top_p=top_p,
            max_tokens=max_tokens
        )
        print(response)
        text = response.choices[0].message["content"].strip()
        return text
    except openai.error.OpenAIError as e:
        print(f"An error occurred: {str(e)}")
        return "Failed to generate a response."
def run_query(query, prompt, scrape_boost, top_k, regions, countries):
    xq = retriever.encode([query]).tolist()
    try:
        xc = index_query(xq, top_k, regions, countries)
    except Exception:
        # force reload of the Pinecone connection and retry once
        pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
        st.session_state.index = pinecone.Index("company-description")
        xc = index_query(xq, top_k, regions, countries)

    results = []
    for match in xc['matches']:
        # answer = reader(question=query, context=match["metadata"]['context'])
        score = match['score']
        if 'type' in match['metadata'] and match['metadata']['type'] == 'description-webcontent':
            score = score * scrape_boost
        answer = {'score': score}
        if match['id'].endswith("_description"):
            answer['id'] = match['id'][:-len("_description")]
        elif match['id'].endswith("_webcontent"):
            answer['id'] = match['id'][:-len("_webcontent")]
        else:
            answer['id'] = match['id']
        answer["name"] = match["metadata"]['company_name']
        answer["description"] = match["metadata"]['description'] if "description" in match['metadata'] else ""
        answer["metadata"] = match["metadata"]
        results.append(answer)

    # Summarize the results
    # prompt_txt = """
    # You are a venture capitalist analyst. Below are descriptions of startup companies that are relevant to the user with their relevancy score.
    # Create a summarized report focusing on the top 3 companies.
    # For every company find its uniqueness over the other companies. Use only information from the descriptions.
    # """
    prompt_txt = prompt + """
    Company descriptions: {descriptions}
    User query: {query}
    """
    prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
    prompt = prompt_template.format(descriptions=results[:10], query=query)
    m_text = call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=1024)
    st.write(m_text)

    sorted_result = sorted(results, key=lambda x: x['score'], reverse=True)
    st.markdown("<h2>Related companies</h2>", unsafe_allow_html=True)
    # df = get_companies_data([r['id'] for r in results])
    for r in sorted_result:
        company_name = r["name"]
        description = r["description"]  # .replace(company_name, f"<mark>{company_name}</mark>")
        score = round(r["score"], 4)
        data_type = r["metadata"]["type"] if "type" in r["metadata"] else ""
        region = r["metadata"]["region"]
        country = r["metadata"]["country"]
        card(company_name, description, score, data_type, region, country)
def check_password():
    """Returns `True` if the user had the correct password."""

    def password_entered():
        """Checks whether a password entered by the user is correct."""
        if st.session_state["password"] == st.secrets["password"]:
            st.session_state["password_correct"] = True
            del st.session_state["password"]  # don't store the password
        else:
            st.session_state["password_correct"] = False

    if "password_correct" not in st.session_state:
        # First run, show input for password.
        st.text_input(
            "Password", type="password", on_change=password_entered, key="password"
        )
        return False
    elif not st.session_state["password_correct"]:
        # Password not correct, show input + error.
        st.text_input(
            "Password", type="password", on_change=password_entered, key="password"
        )
        st.error("😕 Password incorrect")
        return False
    else:
        # Password correct.
        return True
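
# Main app: gate everything behind the password check, then build the sidebar
# filters, the editable summarization prompt, and the free-text search box.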
if check_password():
    st.title("")
    st.write("""
    Search for a company in free text
    """)
    st.markdown("""
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
    """, unsafe_allow_html=True)
with open("data/countries.json", "r") as f: | |
countries = json.load(f)['countries'] | |
countries_selectbox = st.sidebar.multiselect("Country", countries, default=[]) | |
all_regions = ('Africa', 'Europe', 'Asia & Pacific', 'North America', 'South/Latin America') | |
region_selectbox = st.sidebar.multiselect("Region", all_regions, default=all_regions) | |
scrape_boost = st.sidebar.number_input('webcontent_boost', value=2.) | |
top_k = st.sidebar.number_input('Top K Results', value=20) | |
# with st.container(): | |
# col1, col2, col3, col4 = st.columns(4) | |
# with col1: | |
# scrape_boost = st.number_input('webcontent_boost', value=2.) | |
# with col2: | |
# top_k = st.number_input('Top K Results', value=20) | |
# with col3: | |
# regions = st.number_input('Region', value=20) | |
# with col4: | |
# countries = st.number_input('Country', value=20) | |
default_prompt = """ | |
summarize the outcome of this search. The context is a list of company names followed by the company's description and a relevance score to the user query. | |
the report should mention the most important companies and how they compare to each other and contain the following sections: | |
1) Title: query text (summarized if more than 20 tokens) | |
2) Best matches: Naming of the 3 companies from the list that are most similar to the search query: | |
- summarize what they are doing | |
- name customers and technology if they are mentioned | |
- compare them to each other and point out what they do differently or what is their unique selling proposition | |
----""" | |
prompt = st.text_area("Enter prompt", value=default_prompt) | |
query = st.text_input("Search!", "") | |
if query != "": | |
run_query(query, prompt, scrape_boost, top_k, region_selectbox, countries_selectbox) | |