hanoch.rahimi@gmail commited on
Commit
6a2ae7a
·
1 Parent(s): adb5688

initial conversation

Browse files
Files changed (2) hide show
  1. app.py +111 -112
  2. utils.py +77 -8
app.py CHANGED
@@ -10,6 +10,8 @@ import streamlit as st
10
  from transformers import AutoTokenizer
11
  from sentence_transformers import SentenceTransformer
12
 
 
 
13
  import utils
14
 
15
  PINECONE_KEY = st.secrets["PINECONE_API_KEY"] # app.pinecone.io
@@ -43,6 +45,7 @@ def init_models():
43
  return retriever, tokenizer#, vectorstore
44
 
45
  retriever, tokenizer = init_models()
 
46
 
47
 
48
  def card(company_id, name, description, score, data_type, region, country, metadata, is_debug):
@@ -59,41 +62,31 @@ def card(company_id, name, description, score, data_type, region, country, metad
59
  except Exception as e:
60
  print(f"An error occurred: {str(e)}")
61
 
62
-
63
-
64
-
65
  markdown = f"""
66
  <div class="row align-items-start" style="padding-bottom:10px;">
67
  <div class="col-md-8 col-sm-8">
68
  <b>{name} (<a href='https://{company_id}'>website</a>).</b>
69
- <p style="">
70
- {description}
71
- </p>
72
- </div>
73
- <div class="col-md-1 col-sm-1">
74
- <span>{country}</span>
75
- </div>
76
- <div class="col-md-1 col-sm-1">
77
- <span>{customer_problem}</span>
78
- </div>
79
- <div class="col-md-1 col-sm-1">
80
- <span>{business_model}</span>
81
- </div>
82
- <div class="col-md-1 col-sm-1">
83
- <button type='button' onclick="like_company({company_id});">Like</button>
84
- <button type='button' onclick="dislike_company({company_id});">DisLike</button>
85
  </div>
 
 
 
86
  """
87
 
88
  if is_debug:
89
  markdown = markdown + f"""
 
 
 
 
90
  <div class="col-md-1 col-sm-1">
91
  <span>{data_type}</span>
92
  <span>[Score: {score}</span>
93
  </div>
94
  """
95
  markdown = markdown + "</div>"
96
- return st.markdown(markdown, unsafe_allow_html=True)
 
97
 
98
 
99
  def index_query(xq, top_k, regions=[], countries=[], index_namespace="websummarized"):
@@ -114,32 +107,8 @@ def index_query(xq, top_k, regions=[], countries=[], index_namespace="websummari
114
  #xc = st.session_state.index.query(xq, top_k=top_k, include_metadata=True, include_vectors = True)
115
  return xc
116
 
117
- def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=4048):
118
- try:
119
- response = openai.ChatCompletion.create(
120
- model=engine,
121
- messages=[{"role": "user", "content": prompt}],
122
- temperature=temp,
123
- max_tokens=max_tokens
124
- )
125
- print(response)
126
- text = response.choices[0].message["content"].strip()
127
- return text
128
- except openai.error.OpenAIError as e:
129
- print(f"An error occurred: {str(e)}")
130
- return "Failed to generate a response."
131
-
132
- def on_prompt_selected():
133
- title = st.session_state.advanced_prompts_select
134
- new_prompt = utils.get_prompt(title)
135
- if len(new_prompt)>0 and len(new_prompt[0])>0:
136
- print(f"Got a prompt for title {title}\n {new_prompt[0]}")
137
- st.session_state.prompt_title_editable = st.session_state.advanced_prompts_select
138
- st.session_state.advanced_prompt_content = new_prompt[0]
139
- else:
140
- print(f"No results for title {st.session_state.advanced_prompts_select}")
141
 
142
- def run_query(query, prompt, scrape_boost, top_k , regions, countries, is_debug, index_namespace):
143
  xq = retriever.encode([query]).tolist()
144
  try:
145
  xc = index_query(xq, top_k, regions, countries)
@@ -182,44 +151,61 @@ def run_query(query, prompt, scrape_boost, top_k , regions, countries, is_debug,
182
  # Create a summarized report focusing on the top3 companies.
183
  # For every company find its uniqueness over the other companies. Use only information from the descriptions.
184
  # """
185
- prompt_txt = prompt + """
186
- Company descriptions: {descriptions}
187
- User query: {query}
188
- """
189
- prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
190
- descriptions = str([f"{res['name']}: {res['data']['Summary']}" for res in results[:20] if 'Summary' in res['data']])
191
- ntokens = len(descriptions.split(" "))
192
- print(f"#Tokens {ntokens}:\n {descriptions}")
193
- prompt = prompt_template.format(descriptions = descriptions, query = query)
194
- m_text = call_openai(prompt, engine="gpt-3.5-turbo-16k", temp=0, top_p=1.0)
195
-
196
- m_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  sorted_results = sorted(results, key=lambda x: x['score'], reverse=True)
199
 
200
- st.markdown("<h2>Related companies</h2>", unsafe_allow_html=True)
201
-
202
  names = []
203
- st.markdown("""
204
- <div class="container-fluid">
205
- <div class="row align-items-start" style="padding-bottom:10px;">
206
- <div class="col-md-8 col-sm-8">
207
- <span>Company</span>
208
- </div>
209
- <div class="col-md-1 col-sm-1">
210
- <span>Country</span>
211
- </div>
212
- <div class="col-md-1 col-sm-1">
213
- <span>Customer Problem</span>
214
- </div>
215
- <div class="col-md-1 col-sm-1">
216
- <span>Business Model</span>
217
- </div>
218
- <div class="col-md-1 col-sm-1">
219
- Actions
220
- </div>
221
- </div>
222
- """, unsafe_allow_html=True)
 
 
 
223
  for r in sorted_results:
224
  company_name = r["name"]
225
  if company_name in names:
@@ -235,41 +221,45 @@ def run_query(query, prompt, scrape_boost, top_k , regions, countries, is_debug,
235
  region = r["metadata"]["region"]
236
  country = r["metadata"]["country"]
237
  company_id = r["metadata"]["company_id"]
238
- card(company_id, company_name, description, score, data_type, region, country, r['data'], is_debug)
239
 
240
- st.markdown('</div>')
 
241
 
242
 
243
- def check_password():
244
- """Returns `True` if the user had the correct password."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
- def password_entered():
247
- """Checks whether a password entered by the user is correct."""
248
- if st.session_state["password"] == st.secrets["password"]:
249
- st.session_state["password_correct"] = True
250
- del st.session_state["password"] # don't store password
251
- else:
252
- st.session_state["password_correct"] = False
253
-
254
- if "password_correct" not in st.session_state:
255
- # First run, show input for password.
256
- st.text_input(
257
- "Password", type="password", on_change=password_entered, key="password"
258
- )
259
- return False
260
- elif not st.session_state["password_correct"]:
261
- # Password not correct, show input + error.
262
- st.text_input(
263
- "Password", type="password", on_change=password_entered, key="password"
264
- )
265
- st.error("😕 Password incorrect")
266
- return False
267
- else:
268
- # Password correct.
269
- return True
270
 
271
- if check_password():
272
- st.title("")
 
 
 
 
 
 
273
 
274
  st.write("""
275
  Search for a company in free text. Describe the type of company you are looking for, the problem they solve and the solution they provide. You can also copy in the description of a similar company to kick off the search.
@@ -317,6 +307,8 @@ if check_password():
317
  ''',
318
  unsafe_allow_html=True
319
  )
 
 
320
  tab_search, tab_advanced = st.tabs(["Search", "Settings"])
321
 
322
 
@@ -331,6 +323,7 @@ if check_password():
331
  scrape_boost = st.number_input('Web to API content ratio', value=1.)
332
  top_k = st.number_input('# Top Results', value=20)
333
  is_debug = st.checkbox("Debug output", value = False, key="debug")
 
334
  index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
335
  liked_companies = st.text_input(label="liked companies", key='liked_companies')
336
  disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
@@ -340,10 +333,16 @@ if check_password():
340
  with tab_search:
341
  #report_type = st.multiselect("Report Type", utils.get_prompts(), key="search_prompts_multiselect")
342
  query = st.text_input("Search!", "")
343
- cluster = st.checkbox("Cluster the results", value = False, key = "cluster")
 
344
  #prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
345
 
346
  if query != "":
347
- prompt = clustering_prompt if cluster else default_prompt
348
- run_query(query, prompt, scrape_boost, top_k, region_selectbox, countries_selectbox, is_debug, index_namespace)
 
 
 
 
 
349
 
 
10
  from transformers import AutoTokenizer
11
  from sentence_transformers import SentenceTransformer
12
 
13
+ import streamlit.components.v1 as components
14
+
15
  import utils
16
 
17
  PINECONE_KEY = st.secrets["PINECONE_API_KEY"] # app.pinecone.io
 
45
  return retriever, tokenizer#, vectorstore
46
 
47
  retriever, tokenizer = init_models()
48
+ #st.session_state.messages = [{"role":"system", "content":"You are an assistant who helps users find startups to invest in."}]
49
 
50
 
51
  def card(company_id, name, description, score, data_type, region, country, metadata, is_debug):
 
62
  except Exception as e:
63
  print(f"An error occurred: {str(e)}")
64
 
 
 
 
65
  markdown = f"""
66
  <div class="row align-items-start" style="padding-bottom:10px;">
67
  <div class="col-md-8 col-sm-8">
68
  <b>{name} (<a href='https://{company_id}'>website</a>).</b>
69
+ <p style="">{description}</p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  </div>
71
+ <div class="col-md-1 col-sm-1"><span>{country}</span></div>
72
+ <div class="col-md-1 col-sm-1"><span>{customer_problem}</span></div>
73
+ <div class="col-md-1 col-sm-1"><span>{business_model}</span></div>
74
  """
75
 
76
  if is_debug:
77
  markdown = markdown + f"""
78
+ <div class="col-md-1 col-sm-1" style="display:none;">
79
+ <button type='button' onclick="like_company({company_id});">Like</button>
80
+ <button type='button' onclick="dislike_company({company_id});">DisLike</button>
81
+ </div>
82
  <div class="col-md-1 col-sm-1">
83
  <span>{data_type}</span>
84
  <span>[Score: {score}</span>
85
  </div>
86
  """
87
  markdown = markdown + "</div>"
88
+ #print(f" markdown for {company_id}\n{markdown}")
89
+ return markdown
90
 
91
 
92
  def index_query(xq, top_k, regions=[], countries=[], index_namespace="websummarized"):
 
107
  #xc = st.session_state.index.query(xq, top_k=top_k, include_metadata=True, include_vectors = True)
108
  return xc
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
+ def run_query(query, prompt, scrape_boost, top_k , regions, countries, is_debug, index_namespace, openai_model):
112
  xq = retriever.encode([query]).tolist()
113
  try:
114
  xc = index_query(xq, top_k, regions, countries)
 
151
  # Create a summarized report focusing on the top3 companies.
152
  # For every company find its uniqueness over the other companies. Use only information from the descriptions.
153
  # """
154
+ if prompt!="":
155
+ descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
156
+ ntokens = len(descriptions.split(" "))
157
+
158
+ print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
159
+
160
+ prompt_txt = prompt + """
161
+ User query: {query}
162
+ Company descriptions: {descriptions}
163
+ """
164
+ prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
165
+ prompt = prompt_template.format(descriptions = descriptions, query = query)
166
+
167
+ print(f"==============================\nPrompt:\n{prompt}\n==============================\n")
168
+ new_message = {"role": "user", "content": prompt}
169
+ m_text = utils.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
170
+
171
+ m_text
172
+
173
+ else:
174
+ new_message = {"role": "user", "content": query}
175
+
176
+ st.session_state.messages.append(new_message)
177
+ render_history()
178
+ # for message in st.session_state.messages:
179
+ # with st.chat_message(message["role"]):
180
+ # st.markdown(message["content"])
181
+ # print(f"History: \n {st.session_state.messages}")
182
 
183
  sorted_results = sorted(results, key=lambda x: x['score'], reverse=True)
184
 
 
 
185
  names = []
186
+ # list_html = """
187
+ # <h2>Companies list</h2>
188
+ # <div class="container-fluid">
189
+ # <div class="row align-items-start" style="padding-bottom:10px;">
190
+ # <div class="col-md-8 col-sm-8">
191
+ # <span>Company</span>
192
+ # </div>
193
+ # <div class="col-md-1 col-sm-1">
194
+ # <span>Country</span>
195
+ # </div>
196
+ # <div class="col-md-1 col-sm-1">
197
+ # <span>Customer Problem</span>
198
+ # </div>
199
+ # <div class="col-md-1 col-sm-1">
200
+ # <span>Business Model</span>
201
+ # </div>
202
+ # <div class="col-md-1 col-sm-1">
203
+ # Actions
204
+ # </div>
205
+ # </div>
206
+ # """
207
+ list_html = "<div class='container-fluid'>"
208
+
209
  for r in sorted_results:
210
  company_name = r["name"]
211
  if company_name in names:
 
221
  region = r["metadata"]["region"]
222
  country = r["metadata"]["country"]
223
  company_id = r["metadata"]["company_id"]
224
+ list_html = list_html + card(company_id, company_name, description, score, data_type, region, country, r['data'], is_debug)
225
 
226
+ list_html = list_html + '</div>'
227
+ st.markdown(list_html, unsafe_allow_html=True)
228
 
229
 
230
+ def render_history():
231
+ with st.session_state.history_container:
232
+
233
+ s = f"""
234
+ <div style='overflow: hidden;'>
235
+ <div id="chat_history" style='overflow-y: scroll;height: 100px;'>
236
+ """
237
+ for m in st.session_state.messages:
238
+ #print(f"Printing message\t {m['role']}: {m['content']}")
239
+ s = s + f"<div>{m['role']}: {m['content']}</div>"
240
+
241
+ s = s + f"""</div>
242
+ </div>
243
+ <script>
244
+ var el = document.getElementById("chat_history");
245
+ console.log(el.scrollTop, el.scrollHeight);
246
+ el.scrollTop = el.scrollHeight;
247
+ console.log(el.scrollTop, el.scrollHeight);
248
+ </script>
249
+ """
250
+
251
+ components.html(s, height=140)
252
+ #st.markdown(s, unsafe_allow_html=True)
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
+ if utils.check_password():
256
+
257
+ st.markdown("<script language='javascript'>console.log('scrolling');</script>", unsafe_allow_html=True)
258
+
259
+ if "messages" not in st.session_state:
260
+ st.session_state.messages = [{"role":"system", "content":"You are an assistant who helps users find startups to invest in."}]
261
+
262
+ st.title("Raized")
263
 
264
  st.write("""
265
  Search for a company in free text. Describe the type of company you are looking for, the problem they solve and the solution they provide. You can also copy in the description of a similar company to kick off the search.
 
307
  ''',
308
  unsafe_allow_html=True
309
  )
310
+ st.session_state.history_container = st.container()
311
+
312
  tab_search, tab_advanced = st.tabs(["Search", "Settings"])
313
 
314
 
 
323
  scrape_boost = st.number_input('Web to API content ratio', value=1.)
324
  top_k = st.number_input('# Top Results', value=20)
325
  is_debug = st.checkbox("Debug output", value = False, key="debug")
326
+ openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
327
  index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
328
  liked_companies = st.text_input(label="liked companies", key='liked_companies')
329
  disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
 
333
  with tab_search:
334
  #report_type = st.multiselect("Report Type", utils.get_prompts(), key="search_prompts_multiselect")
335
  query = st.text_input("Search!", "")
336
+ #cluster = st.checkbox("Cluster the results", value = False, key = "cluster")
337
+ report_type = st.selectbox(label="Response Type", options=["company_list", "standard", "clustered"], index=0)
338
  #prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
339
 
340
  if query != "":
341
+ if report_type=="standard":
342
+ prompt = default_prompt
343
+ elif report_type=="clustered":
344
+ prompt = clustering_prompt
345
+ else:
346
+ prompt = ""
347
+ run_query(query, prompt, scrape_boost, top_k, region_selectbox, countries_selectbox, is_debug, index_namespace, openai_model)
348
 
utils.py CHANGED
@@ -2,6 +2,7 @@ import pandas as pd
2
  import psycopg2
3
  from psycopg2 import extras
4
  import streamlit as st
 
5
 
6
  # def create_connection():
7
  # host = st.secrets["RAIZED_DB_HOST"]
@@ -17,7 +18,51 @@ import streamlit as st
17
  # )
18
 
19
  ###
20
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def get_prompt(title):
23
  return ""
@@ -35,14 +80,21 @@ def get_prompt(title):
35
  # print(f"Results getting {title}")
36
  # return res
37
 
 
 
 
 
 
 
 
 
 
38
  default_prompt = """
39
- summarize the outcome of this search. The context is a list of company names followed by the company's description and a relevance score to the user query.
40
- the report should mention the most important companies and how they compare to each other and contain the following sections:
41
- 1) Title: query text (summarized if more than 20 tokens)
42
- 2) Best matches: Naming of the 3 companies from the list that are most similar to the search query:
43
- - summarize what they are doing
44
  - name customers and technology if they are mentioned
45
- - compare them to each other and point out what they do differently or what is their unique selling proposition
46
  ----"""
47
 
48
  clustering_prompt = """Please create a document with the following headings:
@@ -76,4 +128,21 @@ List with all the companies in this cluster. Each list item should be structured
76
  * name of the company in bold (URL of the company, country location of the company): short summary summary of what the company does (max 30 tokens)
77
  H1: How you could improve your search
78
  “I hope you have already found some interesting matches. I am happy to let you refine your search. Here are some ideas on how to find matches in relation to your original question around (“user query”):”
79
- * List of ideas on how to refine and improve the search"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import psycopg2
3
  from psycopg2 import extras
4
  import streamlit as st
5
+ import openai
6
 
7
  # def create_connection():
8
  # host = st.secrets["RAIZED_DB_HOST"]
 
18
  # )
19
 
20
  ###
21
+
22
+ def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=4048):
23
+ try:
24
+ response = openai.ChatCompletion.create(
25
+ model=engine,
26
+ messages=st.session_state.messages,
27
+ temperature=temp,
28
+ max_tokens=max_tokens
29
+ )
30
+ print(f"Open AI response\n {response}")
31
+ text = response.choices[0].message["content"].strip()
32
+ st.session_state.messages.append({"role": "system", "content": text})
33
+ return text
34
+ except openai.error.OpenAIError as e:
35
+ print(f"An error occurred: {str(e)}")
36
+ return "Failed to generate a response."
37
+
38
+
39
+ def check_password():
40
+ """Returns `True` if the user had the correct password."""
41
+
42
+ def password_entered():
43
+ """Checks whether a password entered by the user is correct."""
44
+ if st.session_state["password"] == st.secrets["password"]:
45
+ st.session_state["password_correct"] = True
46
+ del st.session_state["password"] # don't store password
47
+ else:
48
+ st.session_state["password_correct"] = False
49
+
50
+ if "password_correct" not in st.session_state:
51
+ # First run, show input for password.
52
+ st.text_input(
53
+ "Password", type="password", on_change=password_entered, key="password"
54
+ )
55
+ return False
56
+ elif not st.session_state["password_correct"]:
57
+ # Password not correct, show input + error.
58
+ st.text_input(
59
+ "Password", type="password", on_change=password_entered, key="password"
60
+ )
61
+ st.error("😕 Password incorrect")
62
+ return False
63
+ else:
64
+ # Password correct.
65
+ return True
66
 
67
  def get_prompt(title):
68
  return ""
 
80
  # print(f"Results getting {title}")
81
  # return res
82
 
83
+ # default_prompt = """
84
+ # summarize the outcome of this search. The context is a list of company names followed by the company's description and a relevance score to the user query.
85
+ # the report should mention the most important companies and how they compare to each other and contain the following sections:
86
+ # 1) Title: query text (summarized if more than 20 tokens)
87
+ # 2) Best matches: Naming of the 3 companies from the list that are most similar to the search query:
88
+ # - summarize what they are doing
89
+ # - name customers and technology if they are mentioned
90
+ # - compare them to each other and point out what they do differently or what is their unique selling proposition
91
+ # ----"""
92
  default_prompt = """
93
+ You are an invesment assistant. Below is a user query followed by a list of company descriptions that match the user query.
94
+ the report should mention the most important companies and how they compare to each other and contain the following sections
95
+ - summarize what those companies they are doing
 
 
96
  - name customers and technology if they are mentioned
97
+ - compare the companies to each other and point out what they do differently or what is their unique selling proposition
98
  ----"""
99
 
100
  clustering_prompt = """Please create a document with the following headings:
 
128
  * name of the company in bold (URL of the company, country location of the company): short summary summary of what the company does (max 30 tokens)
129
  H1: How you could improve your search
130
  “I hope you have already found some interesting matches. I am happy to let you refine your search. Here are some ideas on how to find matches in relation to your original question around (“user query”):”
131
+ * List of ideas on how to refine and improve the search"""
132
+
133
+
134
+
135
+
136
+ def on_prompt_selected():
137
+ title = st.session_state.advanced_prompts_select
138
+ new_prompt = utils.get_prompt(title)
139
+ if len(new_prompt)>0 and len(new_prompt[0])>0:
140
+ print(f"Got a prompt for title {title}\n {new_prompt[0]}")
141
+ st.session_state.prompt_title_editable = st.session_state.advanced_prompts_select
142
+ st.session_state.advanced_prompt_content = new_prompt[0]
143
+ else:
144
+ print(f"No results for title {st.session_state.advanced_prompts_select}")
145
+
146
+
147
+
148
+