hanoch.rahimi@gmail commited on
Commit
aac3522
·
1 Parent(s): e54b3e0

moved organization ot secrets

Browse files
Files changed (4) hide show
  1. app.py +27 -26
  2. openai_utils.py +18 -0
  3. requirements.txt +1 -1
  4. utils.py +15 -20
app.py CHANGED
@@ -13,12 +13,13 @@ from sentence_transformers import SentenceTransformer
13
 
14
  import streamlit.components.v1 as components
15
 
16
- import utils
 
17
 
18
  PINECONE_KEY = st.secrets["PINECONE_API_KEY"] # app.pinecone.io
19
  OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"] # app.pinecone.io
20
  PINE_CONE_ENVIRONMENT = st.secrets["PINE_CONE_ENVIRONMENT"] # app.pinecone.io
21
-
22
  model_name = 'text-embedding-ada-002'
23
 
24
  embed = OpenAIEmbeddings(
@@ -43,7 +44,6 @@ def init_models():
43
  #reader = pipeline(tokenizer=model_name, model=model_name, task='question-answering')
44
  tokenizer = AutoTokenizer.from_pretrained(model_name)
45
  #vectorstore = Pinecone(st.session_state.index, embed.embed_query, text_field)
46
- st.session_state.openai_client = openai.OpenAI(api_key = OPENAI_API_KEY,organization='org-EEpryZYLlh0mZJOGxVko32qP')
47
  # client.beta.assistants.create(
48
  # instructions=utils.assistant_instructions,
49
  # model="gpt-4-1106-preview",
@@ -51,6 +51,7 @@ def init_models():
51
  return retriever, tokenizer#, vectorstore
52
 
53
 
 
54
  retriever, tokenizer = init_models()
55
  #st.session_state.messages = [{"role":"system", "content":"You are an assistant who helps users find startups to invest in."}]
56
 
@@ -115,7 +116,7 @@ def index_query(xq, top_k, regions=[], countries=[], index_namespace="websummari
115
  return xc
116
 
117
 
118
- def run_query(query, prompt, scrape_boost, top_k , regions, countries, is_debug, index_namespace, openai_model):
119
  xq = retriever.encode([query]).tolist()
120
  try:
121
  xc = index_query(xq, top_k, regions, countries)
@@ -129,8 +130,8 @@ def run_query(query, prompt, scrape_boost, top_k , regions, countries, is_debug,
129
  for match in xc['matches']:
130
  #answer = reader(question=query, context=match["metadata"]['context'])
131
  score = match['score']
132
- if 'type' in match['metadata'] and match['metadata']['type']!='description-webcontent' and scrape_boost>0:
133
- score = score / scrape_boost
134
  answer = {'score': score, 'metadata': match['metadata']}
135
  if match['id'].endswith("_description"):
136
  answer['id'] = match['id'][:-12]
@@ -158,7 +159,8 @@ def run_query(query, prompt, scrape_boost, top_k , regions, countries, is_debug,
158
  # Create a summarized report focusing on the top3 companies.
159
  # For every company find its uniqueness over the other companies. Use only information from the descriptions.
160
  # """
161
- if prompt!="":
 
162
  descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
163
  ntokens = len(descriptions.split(" "))
164
 
@@ -172,11 +174,12 @@ def run_query(query, prompt, scrape_boost, top_k , regions, countries, is_debug,
172
  prompt = prompt_template.format(descriptions = descriptions, query = query)
173
 
174
  print(f"==============================\nPrompt:\n{prompt}\n==============================\n")
175
- new_message = {"role": "user", "content": prompt}
176
- m_text = utils.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
177
 
178
  m_text
179
 
 
 
180
  else:
181
  new_message = {"role": "user", "content": query}
182
 
@@ -238,39 +241,36 @@ def render_history():
238
  with st.session_state.history_container:
239
 
240
  s = f"""
241
- <div style='overflow: hidden;'>
242
- <div id="chat_history" style='overflow-y: scroll;height: 100px;'>
243
  """
244
  for m in st.session_state.messages:
245
  #print(f"Printing message\t {m['role']}: {m['content']}")
246
- s = s + f"<div>{m['role']}: {m['content']}</div>"
247
 
248
  s = s + f"""</div>
249
  </div>
250
  <script>
251
  var el = document.getElementById("chat_history");
252
- console.log(el.scrollTop, el.scrollHeight);
253
  el.scrollTop = el.scrollHeight;
254
- console.log(el.scrollTop, el.scrollHeight);
255
  </script>
256
  """
257
 
258
- components.html(s, height=140)
259
  #st.markdown(s, unsafe_allow_html=True)
260
 
261
-
262
  if utils.check_password():
263
 
264
  st.markdown("<script language='javascript'>console.log('scrolling');</script>", unsafe_allow_html=True)
265
 
266
- if "messages" not in st.session_state:
267
- st.session_state.messages = [{"role":"system", "content":"You are an assistant who helps users find startups to invest in."}]
 
 
268
 
269
- st.title("Raized")
270
 
271
- st.write("""
272
- Search for a company in free text. Describe the type of company you are looking for, the problem they solve and the solution they provide. You can also copy in the description of a similar company to kick off the search.
273
- """)
274
 
275
  st.markdown("""
276
  <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
@@ -278,6 +278,7 @@ if utils.check_password():
278
  with open("data/countries.json", "r") as f:
279
  countries = json.load(f)['countries']
280
  header = st.sidebar.markdown("Filters")
 
281
  countries_selectbox = st.sidebar.multiselect("Country", countries, default=[])
282
  all_regions = ('Africa', 'Europe', 'Asia & Pacific', 'North America', 'South/Latin America')
283
  region_selectbox = st.sidebar.multiselect("Region", all_regions, default=all_regions)
@@ -314,7 +315,6 @@ if utils.check_password():
314
  ''',
315
  unsafe_allow_html=True
316
  )
317
- st.session_state.history_container = st.container()
318
 
319
  tab_search, tab_advanced = st.tabs(["Search", "Settings"])
320
 
@@ -322,12 +322,13 @@ if utils.check_password():
322
  with tab_advanced:
323
  #prompt_title = st.selectbox("Report Type", index = 0, options = utils.get_prompts(), on_change=on_prompt_selected, key="advanced_prompts_select", )
324
  #prompt_title_editable = st.text_input("Title", key="prompt_title_editable")
 
325
  default_prompt = st.text_area("Default Prompt", value = utils.default_prompt, height=400, key="advanced_default_prompt_content")
326
  clustering_prompt = st.text_area("Clustering Prompt", value = utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")
327
  #prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
328
  #prompt_delete = st.button("Del", on_click = utils.del_prompt(prompt_title_editable))
329
  #prompt_save = st.button("Save", on_click = utils.save_prompt(prompt_title_editable, prompt))
330
- scrape_boost = st.number_input('Web to API content ratio', value=1.)
331
  top_k = st.number_input('# Top Results', value=20)
332
  is_debug = st.checkbox("Debug output", value = False, key="debug")
333
  openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
@@ -339,9 +340,9 @@ if utils.check_password():
339
 
340
  with tab_search:
341
  #report_type = st.multiselect("Report Type", utils.get_prompts(), key="search_prompts_multiselect")
 
342
  query = st.text_input("Search!", "")
343
  #cluster = st.checkbox("Cluster the results", value = False, key = "cluster")
344
- report_type = st.selectbox(label="Response Type", options=["company_list", "standard", "clustered"], index=0)
345
  #prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
346
 
347
  if query != "":
@@ -351,5 +352,5 @@ if utils.check_password():
351
  prompt = clustering_prompt
352
  else:
353
  prompt = ""
354
- run_query(query, prompt, scrape_boost, top_k, region_selectbox, countries_selectbox, is_debug, index_namespace, openai_model)
355
 
 
13
 
14
  import streamlit.components.v1 as components
15
 
16
+ import utils
17
+ import openai_utils as oai
18
 
19
  PINECONE_KEY = st.secrets["PINECONE_API_KEY"] # app.pinecone.io
20
  OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"] # app.pinecone.io
21
  PINE_CONE_ENVIRONMENT = st.secrets["PINE_CONE_ENVIRONMENT"] # app.pinecone.io
22
+ OPENAI_ORGANIZATION_ID = st.secrets["OPENAI_ORGANIZATION_ID"]
23
  model_name = 'text-embedding-ada-002'
24
 
25
  embed = OpenAIEmbeddings(
 
44
  #reader = pipeline(tokenizer=model_name, model=model_name, task='question-answering')
45
  tokenizer = AutoTokenizer.from_pretrained(model_name)
46
  #vectorstore = Pinecone(st.session_state.index, embed.embed_query, text_field)
 
47
  # client.beta.assistants.create(
48
  # instructions=utils.assistant_instructions,
49
  # model="gpt-4-1106-preview",
 
51
  return retriever, tokenizer#, vectorstore
52
 
53
 
54
+ st.session_state.openai_client = openai.OpenAI(api_key = OPENAI_API_KEY,organization=OPENAI_ORGANIZATION_ID)
55
  retriever, tokenizer = init_models()
56
  #st.session_state.messages = [{"role":"system", "content":"You are an assistant who helps users find startups to invest in."}]
57
 
 
116
  return xc
117
 
118
 
119
+ def run_query(query, prompt, top_k , regions, countries, is_debug, index_namespace, openai_model):
120
  xq = retriever.encode([query]).tolist()
121
  try:
122
  xc = index_query(xq, top_k, regions, countries)
 
130
  for match in xc['matches']:
131
  #answer = reader(question=query, context=match["metadata"]['context'])
132
  score = match['score']
133
+ # if 'type' in match['metadata'] and match['metadata']['type']!='description-webcontent' and scrape_boost>0:
134
+ # score = score / scrape_boost
135
  answer = {'score': score, 'metadata': match['metadata']}
136
  if match['id'].endswith("_description"):
137
  answer['id'] = match['id'][:-12]
 
159
  # Create a summarized report focusing on the top3 companies.
160
  # For every company find its uniqueness over the other companies. Use only information from the descriptions.
161
  # """
162
+ if prompt!="" or st.session_state.new_conversation:
163
+ st.session_state.new_conversation = False
164
  descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
165
  ntokens = len(descriptions.split(" "))
166
 
 
174
  prompt = prompt_template.format(descriptions = descriptions, query = query)
175
 
176
  print(f"==============================\nPrompt:\n{prompt}\n==============================\n")
177
+ m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
 
178
 
179
  m_text
180
 
181
+ new_message = {"role": "user", "content": query}
182
+
183
  else:
184
  new_message = {"role": "user", "content": query}
185
 
 
241
  with st.session_state.history_container:
242
 
243
  s = f"""
244
+ <div style='overflow: hidden; padding:10px 0px;'>
245
+ <div id="chat_history" style='overflow-y: scroll;height: 200px;'>
246
  """
247
  for m in st.session_state.messages:
248
  #print(f"Printing message\t {m['role']}: {m['content']}")
249
+ s = s + f"<div class='chat_message'><b>{m['role']}</b>: {m['content']}</div>"
250
 
251
  s = s + f"""</div>
252
  </div>
253
  <script>
254
  var el = document.getElementById("chat_history");
 
255
  el.scrollTop = el.scrollHeight;
 
256
  </script>
257
  """
258
 
259
+ components.html(s, height=220)
260
  #st.markdown(s, unsafe_allow_html=True)
261
 
 
262
  if utils.check_password():
263
 
264
  st.markdown("<script language='javascript'>console.log('scrolling');</script>", unsafe_allow_html=True)
265
 
266
+ if st.sidebar.button("New Conversation") or "messages" not in st.session_state:
267
+ st.session_state.new_conversation = True
268
+ st.session_state.messages = [{"role":"system", "content":"Hello. I'm your startups discovery assistant."}]
269
+
270
 
271
+ st.title("Raized- Startups discovery demo")
272
 
273
+ #st.write("Search for a company in free text. Describe the type of company you are looking for, the problem they solve and the solution they provide. You can also copy in the description of a similar company to kick off the search.")
 
 
274
 
275
  st.markdown("""
276
  <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
 
278
  with open("data/countries.json", "r") as f:
279
  countries = json.load(f)['countries']
280
  header = st.sidebar.markdown("Filters")
281
+ #new_conversation = st.sidebar.button("New Conversation", key="new_conversation")
282
  countries_selectbox = st.sidebar.multiselect("Country", countries, default=[])
283
  all_regions = ('Africa', 'Europe', 'Asia & Pacific', 'North America', 'South/Latin America')
284
  region_selectbox = st.sidebar.multiselect("Region", all_regions, default=all_regions)
 
315
  ''',
316
  unsafe_allow_html=True
317
  )
 
318
 
319
  tab_search, tab_advanced = st.tabs(["Search", "Settings"])
320
 
 
322
  with tab_advanced:
323
  #prompt_title = st.selectbox("Report Type", index = 0, options = utils.get_prompts(), on_change=on_prompt_selected, key="advanced_prompts_select", )
324
  #prompt_title_editable = st.text_input("Title", key="prompt_title_editable")
325
+ report_type = st.selectbox(label="Response Type", options=["company_list", "standard", "clustered"], index=0)
326
  default_prompt = st.text_area("Default Prompt", value = utils.default_prompt, height=400, key="advanced_default_prompt_content")
327
  clustering_prompt = st.text_area("Clustering Prompt", value = utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")
328
  #prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
329
  #prompt_delete = st.button("Del", on_click = utils.del_prompt(prompt_title_editable))
330
  #prompt_save = st.button("Save", on_click = utils.save_prompt(prompt_title_editable, prompt))
331
+ #scrape_boost = st.number_input('Web to API content ratio', value=1.)
332
  top_k = st.number_input('# Top Results', value=20)
333
  is_debug = st.checkbox("Debug output", value = False, key="debug")
334
  openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
 
340
 
341
  with tab_search:
342
  #report_type = st.multiselect("Report Type", utils.get_prompts(), key="search_prompts_multiselect")
343
+ st.session_state.history_container = st.container()
344
  query = st.text_input("Search!", "")
345
  #cluster = st.checkbox("Cluster the results", value = False, key = "cluster")
 
346
  #prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
347
 
348
  if query != "":
 
352
  prompt = clustering_prompt
353
  else:
354
  prompt = ""
355
+ run_query(query, prompt, top_k, region_selectbox, countries_selectbox, is_debug, index_namespace, openai_model)
356
 
openai_utils.py CHANGED
@@ -2,6 +2,24 @@ import time
2
  import streamlit as st
3
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  def send_message(role, content):
6
  message = st.session_state.openai_client.beta.threads.messages.create(
7
  thread_id=st.session_state.assistant_thread.id,
 
2
  import streamlit as st
3
 
4
 
5
+ def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=4048):
6
+ try:
7
+ response = st.session_state.openai_client.chat.completions.create(
8
+ model=engine,
9
+ messages=st.session_state.messages,
10
+ temperature=temp,
11
+ max_tokens=max_tokens
12
+ )
13
+ print(f"====================\nOpen AI response\n {response}\n====================\n")
14
+ text = response.choices[0].message.content.strip()
15
+ st.session_state.messages.append({"role": "system", "content": text})
16
+ return text
17
+ except Exception as e:
18
+ #except openai.error.OpenAIError as e:
19
+ print(f"An error occurred: {str(e)}")
20
+ return "Failed to generate a response."
21
+
22
+
23
  def send_message(role, content):
24
  message = st.session_state.openai_client.beta.threads.messages.create(
25
  thread_id=st.session_state.assistant_thread.id,
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  langchain
2
- openai
3
  pinecone-client
4
  psycopg2-binary==2.8.6
5
  sentence_transformers
 
1
  langchain
2
+ openai==1.2.4
3
  pinecone-client
4
  psycopg2-binary==2.8.6
5
  sentence_transformers
utils.py CHANGED
@@ -19,23 +19,6 @@ import openai
19
 
20
  ###
21
 
22
- def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=4048):
23
- try:
24
- response = st.session_state.openai_client.chat.completions.create(
25
- model=engine,
26
- messages=st.session_state.messages,
27
- temperature=temp,
28
- max_tokens=max_tokens
29
- )
30
- print(f"Open AI response\n {response}")
31
- text = response.choices[0].message.content.strip()
32
- st.session_state.messages.append({"role": "system", "content": text})
33
- return text
34
- except Exception as e:
35
- #except openai.error.OpenAIError as e:
36
- print(f"An error occurred: {str(e)}")
37
- return "Failed to generate a response."
38
-
39
 
40
  def check_password():
41
  """Returns `True` if the user had the correct password."""
@@ -113,14 +96,26 @@ Also name the ranking criteria and suggest how to combine them to best meet the
113
  # - name customers and technology if they are mentioned
114
  # - compare them to each other and point out what they do differently or what is their unique selling proposition
115
  # ----"""
 
116
  default_prompt = """
117
- You are an invesment assistant. Below is a user query followed by a list of company descriptions that match the user query.
118
- the report should mention the most important companies and how they compare to each other and contain the following sections
 
 
 
119
  - summarize what those companies they are doing
120
  - name customers and technology if they are mentioned
121
  - compare the companies to each other and point out what they do differently or what is their unique selling proposition
122
  ----"""
123
 
 
 
 
 
 
 
 
 
124
  clustering_prompt = """Please create a document with the following headings:
125
  H2: Recap of your question
126
  H2: Clusters of relevant companies
@@ -159,7 +154,7 @@ H1: How you could improve your search
159
 
160
  def on_prompt_selected():
161
  title = st.session_state.advanced_prompts_select
162
- new_prompt = utils.get_prompt(title)
163
  if len(new_prompt)>0 and len(new_prompt[0])>0:
164
  print(f"Got a prompt for title {title}\n {new_prompt[0]}")
165
  st.session_state.prompt_title_editable = st.session_state.advanced_prompts_select
 
19
 
20
  ###
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def check_password():
24
  """Returns `True` if the user had the correct password."""
 
96
  # - name customers and technology if they are mentioned
97
  # - compare them to each other and point out what they do differently or what is their unique selling proposition
98
  # ----"""
99
+
100
  default_prompt = """
101
+ You are an assistant and your job is to help the user discover and analyze startups companies. You first need to understand what type of startups the user is looking and then create a report with an analysis of companies relevant to the user's query.
102
+ Use only information from the explicit list of companies provided!
103
+ Below is the user query followed by a list of company descriptions that match the user query. If you don't have enough information in the user query, offer the user ways to improve the query.
104
+ Don't teach the user about investment though.
105
+ The report should mention the most important companies and how they compare to each other and contain the following sections
106
  - summarize what those companies they are doing
107
  - name customers and technology if they are mentioned
108
  - compare the companies to each other and point out what they do differently or what is their unique selling proposition
109
  ----"""
110
 
111
+ query_finetune_prompt = """
112
+ You are an assistant and your job is to help the user discover and analyze startups companies.
113
+ You first need to understand what type of startups the user is looking and then create a report with an analysis of companies relevant to the user's query.
114
+ """
115
+
116
+ summarization_prompt = """
117
+ """
118
+
119
  clustering_prompt = """Please create a document with the following headings:
120
  H2: Recap of your question
121
  H2: Clusters of relevant companies
 
154
 
155
  def on_prompt_selected():
156
  title = st.session_state.advanced_prompts_select
157
+ new_prompt = get_prompt(title)
158
  if len(new_prompt)>0 and len(new_prompt[0])>0:
159
  print(f"Got a prompt for title {title}\n {new_prompt[0]}")
160
  st.session_state.prompt_title_editable = st.session_state.advanced_prompts_select