File size: 9,683 Bytes
19e6802
eab6925
 
 
 
 
 
19e6802
5bf3195
eab6925
5bf3195
 
a280e4d
eab6925
c8b8b02
eab6925
c8b8b02
5bf3195
eab6925
 
 
 
 
 
 
 
19e6802
 
75a550e
5bf3195
c8b8b02
 
eab6925
 
5bf3195
75a550e
5bf3195
c8b8b02
045b802
d631df4
eab6925
d631df4
eab6925
 
5bf3195
eab6925
5bf3195
 
19e6802
5bf3195
 
a280e4d
0b0b1a7
a280e4d
 
 
5bf3195
 
19e6802
a280e4d
19e6802
 
a280e4d
19e6802
 
a280e4d
 
0b0b1a7
 
5bf3195
 
 
 
19e6802
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eab6925
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bf3195
 
19e6802
5bf3195
 
c8b8b02
ac8d16a
19e6802
5bf3195
 
 
d631df4
19e6802
 
 
 
a280e4d
 
 
 
 
 
 
19e6802
0b0b1a7
5bf3195
 
eab6925
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bf3195
 
a280e4d
 
 
 
5bf3195
d631df4
a280e4d
5bf3195
19e6802
 
 
 
5bf3195
cd6bb4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ac1abc
19e6802
 
cd6bb4b
 
 
 
 
 
 
19e6802
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eab6925
 
 
 
 
 
 
 
 
 
 
 
19e6802
cd6bb4b
 
 
eab6925
5bf3195
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
import json

from langchain.chains import RetrievalQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Pinecone
import openai
import pinecone
import streamlit as st
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer

from utils import get_companies_data

# Secrets are read from Streamlit's secrets store (.streamlit/secrets.toml).
PINECONE_KEY = st.secrets["PINECONE_API_KEY"]  # app.pinecone.io
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]  # app.pinecone.io
PINE_CONE_ENVIRONMENT = st.secrets["PINE_CONE_ENVIRONMENT"]  # app.pinecone.io

# OpenAI embedding model name for the LangChain embedder below.
model_name = 'text-embedding-ada-002'

# NOTE(review): `embed` is built at import time but is only referenced from
# commented-out code (the Pinecone vectorstore line in init_models) — confirm
# whether it is still needed before removing.
embed = OpenAIEmbeddings(
    model=model_name,
    openai_api_key=OPENAI_API_KEY
)


# Use the full browser width so the result cards fit on one row.
st.set_page_config(layout="wide")

@st.cache_resource
def init_pinecone():
    """Connect to Pinecone and return the company-description index handle.

    Cached by Streamlit so the connection is established once per process,
    not on every script rerun.
    """
    # get a free api key from app.pinecone.io
    pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
    # Bug fix: the index name was misspelled "dompany-description"; the
    # reconnect path in run_query uses "company-description".
    return pinecone.Index("company-description")

# Stash the (cached) index handle in session state so helpers can reach it
# and run_query can swap in a fresh handle after a connection failure.
st.session_state.index = init_pinecone()
    
@st.cache_resource
def init_models():
    """Load the sentence-embedding retriever and its tokenizer (cached).

    Returns:
        (retriever, tokenizer): a SentenceTransformer used to encode search
        queries, and the matching HuggingFace tokenizer.
    """
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    # Reuse the constant instead of duplicating the model-name literal.
    retriever = SentenceTransformer(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return retriever, tokenizer

# Module-level singletons: query encoder and its tokenizer (cached resources).
retriever, tokenizer = init_models()


def card(name, description, score, data_type, region, country):
    """Render one company search result as a Bootstrap-grid row.

    Args:
        name: Company name (bolded).
        description: Free-text company description.
        score: Relevance score (already boosted/rounded by the caller).
        data_type: Source of the match, e.g. "description-webcontent".
        region, country: Location metadata shown in their own columns.

    Returns:
        The DeltaGenerator from st.markdown (callers ignore it).
    """
    return st.markdown(f"""
    <div class="container-fluid">
        <div class="row align-items-start" style="padding-bottom:10px;">
             <div  class="col-md-8 col-sm-8">
                 <b>{name}.</b>
                 <span style="">
                     {description}
                 </span>
             </div>
             <div  class="col-md-1 col-sm-1">
                    <span>{region}</span>
             </div>        
             <div  class="col-md-1 col-sm-1">
                    <span>{country}</span>
             </div>        
             <div  class="col-md-1 col-sm-1">
                    <span>{data_type}</span>
                    <span>[Score: {score}]</span>
             </div>        
         </div>
     </div>
        """, unsafe_allow_html=True)


def index_query(xq, top_k, regions=None, countries=None):
    """Query the Pinecone index with optional region/country metadata filters.

    Args:
        xq: Query embedding (list of vectors) produced by the retriever.
        top_k: Number of matches to request from Pinecone.
        regions: Optional list of region names; empty/None means no filter.
        countries: Optional list of country names; empty/None means no filter.

    Returns:
        The raw Pinecone query response (with metadata and vectors included).
    """
    # Avoid mutable default arguments; treat None as "no filter".
    regions = regions or []
    countries = countries or []

    clauses = []
    if regions:
        clauses.append({'region': {"$in": regions}})
    if countries:
        clauses.append({'country': {"$in": countries}})

    # Combine clauses per Pinecone filter syntax ($and for multiple clauses).
    if len(clauses) == 1:
        metadata_filter = clauses[0]
    elif clauses:
        metadata_filter = {"$and": clauses}
    else:
        metadata_filter = {}

    # Bug fix: top_k was previously hard-coded to 20, silently ignoring the
    # caller's value (the sidebar "Top K Results" setting).
    return st.session_state.index.query(
        xq,
        top_k=top_k,
        filter=metadata_filter,
        include_metadata=True,
        include_vectors=True,
    )

def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=1024):
    """Send a chat-completion request and return the assistant's reply text.

    Args:
        prompt: User-message content (the filled summarization prompt).
        engine: Chat model name passed as `model`.
        temp: Sampling temperature.
        top_p: Nucleus-sampling parameter.
        max_tokens: Cap on the generated response length.

    Returns:
        The stripped reply text, or a fallback message if the API errors.
    """
    try:
        response = openai.ChatCompletion.create(
            model=engine,
            messages=[
                {"role": "system", "content": "You are an assistant analyzing startup companies for investments."},
                {"role": "user", "content": prompt}
            ],
            temperature=temp,
            # Bug fix: top_p was accepted by this function but never sent.
            top_p=top_p,
            max_tokens=max_tokens
        )
        return response.choices[0].message["content"].strip()
    except openai.error.OpenAIError as e:
        # Best-effort: log and return a sentinel rather than crash the UI.
        print(f"An error occurred: {str(e)}")
    return "Failed to generate a response."



def run_query(query, prompt, scrape_boost, top_k, regions, countries):
    """Encode the query, search Pinecone, summarize via OpenAI, render cards.

    Args:
        query: Free-text search string from the user.
        prompt: Summarization prompt prefix (from the sidebar text area).
        scrape_boost: Multiplier applied to web-content match scores.
        top_k: Number of matches to request from the index.
        regions, countries: Metadata filter lists forwarded to index_query.
    """
    xq = retriever.encode([query]).tolist()
    try:
        xc = index_query(xq, top_k, regions, countries)
    except Exception:  # narrowed from bare except; Pinecone handle can go stale
        # Force a reconnect and retry once with a fresh index handle.
        pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
        st.session_state.index = pinecone.Index("company-description")
        xc = index_query(xq, top_k, regions, countries)

    results = []
    for match in xc['matches']:
        metadata = match['metadata']
        score = match['score']
        # Boost matches that came from scraped web content.
        if metadata.get('type') == 'description-webcontent':
            score = score * scrape_boost

        # Strip the source-type suffix so the id maps back to the company.
        match_id = match['id']
        if match_id.endswith("_description"):
            company_id = match_id[:-len("_description")]
        elif match_id.endswith("_webcontent"):
            company_id = match_id[:-len("_webcontent")]
        else:
            company_id = match_id

        results.append({
            'score': score,
            'id': company_id,
            'name': metadata['company_name'],
            'description': metadata.get('description', ""),
            'metadata': metadata,
        })

    # Build the summarization prompt from the user's prefix plus context
    # (only the first 10 results are sent to keep the prompt small).
    prompt_txt = prompt + """        
    Company descriptions: {descriptions}
    User query: {query} 
    """    
    prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
    filled_prompt = prompt_template.format(descriptions=results[:10], query=query)
    m_text = call_openai(filled_prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=1024)

    # Explicit write instead of Streamlit's bare-expression "magic" display.
    st.write(m_text)

    sorted_result = sorted(results, key=lambda x: x['score'], reverse=True)

    st.markdown("<h2>Related companies</h2>", unsafe_allow_html=True)

    for r in sorted_result:
        company_name = r["name"]
        description = r["description"]
        score = round(r["score"], 4)
        data_type = r["metadata"].get("type", "")
        region = r["metadata"]["region"]
        country = r["metadata"]["country"]
        card(company_name, description, score, data_type, region, country)

def check_password():
    """Returns `True` if the user had the correct password."""

    def password_entered():
        """Checks whether a password entered by the user is correct."""
        if st.session_state["password"] == st.secrets["password"]:
            st.session_state["password_correct"] = True
            del st.session_state["password"]  # don't store password
        else:
            st.session_state["password_correct"] = False

    # Already authenticated in this session: nothing to render.
    if st.session_state.get("password_correct"):
        return True

    # First run or a failed attempt: render the password prompt either way.
    st.text_input(
        "Password", type="password", on_change=password_entered, key="password"
    )
    if "password_correct" in st.session_state:
        # The flag exists but is False, i.e. a previous attempt failed.
        st.error("😕 Password incorrect")
    return False

# Main app UI: rendered only after a successful password check.
if check_password():
    st.title("")

    st.write("""
    Search for a company in free text
    """)

    # Load Bootstrap so the card() grid classes render correctly.
    st.markdown("""
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
    """, unsafe_allow_html=True)
    # Country list for the sidebar filter; assumes data/countries.json has a
    # top-level "countries" array — TODO confirm schema against the data file.
    with open("data/countries.json", "r") as f:
        countries = json.load(f)['countries']
    countries_selectbox = st.sidebar.multiselect("Country", countries, default=[])
    all_regions = ('Africa', 'Europe', 'Asia & Pacific', 'North America', 'South/Latin America')
    region_selectbox = st.sidebar.multiselect("Region", all_regions, default=all_regions)
    # Score multiplier for scraped-web-content matches (see run_query).
    scrape_boost = st.sidebar.number_input('webcontent_boost', value=2.)
    top_k = st.sidebar.number_input('Top K Results', value=20)

    # with st.container():
        # col1, col2, col3, col4 = st.columns(4)
        # with col1:
            # scrape_boost = st.number_input('webcontent_boost', value=2.)
        # with col2:
            # top_k = st.number_input('Top K Results', value=20)
        # with col3:
            # regions = st.number_input('Region', value=20)
        # with col4:
            # countries = st.number_input('Country', value=20)
    # Default summarization prompt; the user can edit it in the text area.
    default_prompt = """
summarize the outcome of this search. The context is a list of company names followed by the company's description and a relevance score to the user query. 
the report should mention the most important companies and how they compare to each other and contain the following sections:
1) Title: query text (summarized if more than 20 tokens)
2) Best matches: Naming of the 3 companies from the list that are most similar to the search query:
- summarize what they are doing
- name customers and technology if they are mentioned
- compare them to each other and point out what they do differently or what is their unique selling proposition
----"""    

    prompt = st.text_area("Enter prompt", value=default_prompt)


    query = st.text_input("Search!", "")

    # Only run the search once the user has typed something.
    if query != "":
        run_query(query, prompt, scrape_boost, top_k, region_selectbox, countries_selectbox)