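"""Streamlit demo app for Raized.AI startup discovery.

Users describe the startups they are looking for in free text; the app answers via a
Google (Gemini) agent, an OpenAI Assistant, or a vector search over company
descriptions summarized by an OpenAI chat model, and renders the matching companies
as a table and a map.
"""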
import asyncio

if hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

import json
import logging
from sys import exc_info

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

import streamlit as st
from googleai import send_message as google_send_message, init_googleai
from langchain.chains import RetrievalQA
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.schema import AIMessage, HumanMessage, SystemMessage
import pandas as pd
from PIL import Image
from streamlit.runtime.state import session_state
import openai
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import streamlit.components.v1 as components

# st.set_page_config(
#     layout="wide",
#     initial_sidebar_state="collapsed",
#     page_title="RaizedAI Startup Discovery Assistant",
#     #page_icon=":robot:"
# )

import utils
import openai_utils as oai
from streamlit_extras.stylable_container import stylable_container

# OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]  # app.pinecone.io
# model_name = 'text-embedding-ada-002'
# embed = OpenAIEmbeddings(
#     model=model_name,
#     openai_api_key=OPENAI_API_KEY
# )

#"🤖",
#st.image("resources/raized_logo.png")
assistant_avatar = Image.open('resources/raized_logo.png')

carddict = {
    "name": [],
    "company_id": [],
    "description": [],
    "country": [],
    "customer_problem": [],
    "target_customer": [],
    "business_model": []
}


def init_models():
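    """Load the sentence-transformer retriever and its matching tokenizer.

    Both are based on the msmarco-distilbert-base-v4 checkpoint; returns
    (retriever, tokenizer).
    """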
    retriever = SentenceTransformer("msmarco-distilbert-base-v4")
    #model_name = "sentence-transformers/all-MiniLM-L6-v2"
    model_name = "sentence-transformers/msmarco-distilbert-base-v4"
    #retriever = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    #reader = pipeline(tokenizer=model_name, model=model_name, task='question-answering')
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return retriever, tokenizer  #, vectorstore


def init_openai():
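    """Create the OpenAI client via openai_utils, keep it in session state,
    and return the 20 most recently created assistants."""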
    st.session_state.openai_client = oai.get_client()
    assistants = st.session_state.openai_client.beta.assistants.list(
        order="desc",
        limit="20",
    )
    return assistants


assistants = init_openai()
retriever, tokenizer = init_models()
st.session_state.retriever = retriever

# AVATAR_PATHS = {"assistant": st.image("resources/raized_logo.png"),
#                 "user": "👩⚖️"}
#st.session_state.messages = [{"role":"system", "content":"You are an assistant who helps users find startups to invest in."}]


def card(company_id, name, description, score, data_type, region, country, metadata, is_debug):
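    """Build an HTML card for one search result and record its fields in `carddict`.

    Summary, customer problem, target customer and business model are taken from
    `metadata` when present; debug mode adds like/dislike buttons and the score.
    """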
    if 'Summary' in metadata:
        description = metadata['Summary']
    customer_problem = metadata['Customer problem'] if 'Customer problem' in metadata else ""
    target_customer = metadata['Target customer'] if 'Target customer' in metadata else ""
    business_model = ""
    if 'Business model' in metadata:
        try:
            business_model = metadata['Business model']
            #business_model = json.loads(metadata['Business model'])
        except Exception as e:
            print(f"An error occurred: {str(e)}")
    markdown = f"""
    <div class="row align-items-start" style="padding-bottom:10px;">
        <div class="col-md-8 col-sm-8">
            <b>{name} (<a href='https://{company_id}'>website</a>).</b>
            <p style="">{description}</p>
        </div>
        <div class="col-md-1 col-sm-1"><span>{country}</span></div>
        <div class="col-md-1 col-sm-1"><span>{customer_problem}</span></div>
        <div class="col-md-1 col-sm-1"><span>{target_customer}</span></div>
        <div class="col-md-1 col-sm-1"><span>{business_model}</span></div>
    """
    business_model_str = ", ".join(business_model)
    company_id_url = "https://" + company_id
    carddict["name"].append(name)
    carddict["company_id"].append(company_id_url)
    carddict["description"].append(description)
    carddict["country"].append(country)
    carddict["customer_problem"].append(customer_problem)
    carddict["target_customer"].append(target_customer)
    carddict["business_model"].append(business_model_str)
    if is_debug:
        markdown = markdown + f"""
        <div class="col-md-1 col-sm-1" style="display:none;">
            <button type='button' onclick="like_company({company_id});">Like</button>
            <button type='button' onclick="dislike_company({company_id});">DisLike</button>
        </div>
        <div class="col-md-1 col-sm-1">
            <span>{data_type}</span>
            <span>[Score: {score}]</span>
        </div>
        """
    markdown = markdown + "</div>"
    #print(f" markdown for {company_id}\n{markdown}")
    return markdown


def run_query(query, report_type, top_k, regions, countries, is_debug, index_namespace, openai_model, default_prompt):
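    """Run one search/chat turn for `query` and render the response.

    `report_type` selects the flow: "gemini" forwards the query to the Google agent,
    "assistant" uses the OpenAI Assistants flow in openai_utils, "guided" first
    extracts search keywords before the vector search, and the remaining types run
    a vector search whose results are summarized with the default or clustering
    prompt. Matching companies are also rendered as a dataframe and a map.
    """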
    #Summarize the results
    # prompt_txt = """
    # You are a venture capitalist analyst. Below are descriptions of startup companies that are relevant to the user with their relevancy score.
    # Create a summarized report focusing on the top 3 companies.
    # For every company find its uniqueness over the other companies. Use only information from the descriptions.
    # """
    content_container = st.container()  #, col_sidepanel = st.columns([4, 1], gap="small")
    if report_type == "gemini":
        try:
            logger.debug(f"User: {query}")
            response = google_send_message(query)
            response = response['output']
            logger.debug(f"Agent: {response}")
            with content_container:
                with st.chat_message(name='User'):
                    st.write(query)
                with st.chat_message(name='Agent', avatar=assistant_avatar):
                    st.write(response)
        except Exception as e:
            logger.exception("Error processing user message", exc_info=e)
    else:
        if report_type == "guided":
            prompt_txt = utils.query_finetune_prompt + """
            User query: {query}
            """
            prompt_template = PromptTemplate(template=prompt_txt, input_variables=["query"])
            prompt = prompt_template.format(query=query)
            m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0, max_tokens=20, log_message=False)
            print(f"Keywords: {m_text}")
            results = utils.search_index(m_text, top_k, regions, countries, retriever, index_namespace)
            descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
            ntokens = len(descriptions.split(" "))
            print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
            prompt_txt = utils.summarization_prompt + """
            User query: {query}
            Company descriptions: {descriptions}
            """
            prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
            prompt = prompt_template.format(descriptions=descriptions, query=query)
            print(f"==============================\nPrompt:\n{prompt}\n==============================\n")
            m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
            m_text  # bare expression: relies on Streamlit "magic" to display the summary
        elif report_type == "company_list":  # or st.session_state.new_conversation:
            results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
            descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
        elif report_type == "assistant":
            #results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
            #descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
            messages = oai.call_assistant(query, engine=openai_model)
            st.session_state.messages = messages
            results = st.session_state.db_search_results
            if messages is not None:
                with content_container:
                    for message in list(messages)[::-1]:
                        if hasattr(message, 'role'):
                            # print(f"\n-----\nMessage: {message}\n")
                            # with st.chat_message(name=message.role):
                            #     st.write(message.content[0].text.value)
                            if message.role == "assistant":
                                with st.chat_message(name=message.role, avatar=assistant_avatar):
                                    st.write(message.content[0].text.value)
                            else:
                                with st.chat_message(name=message.role):
                                    st.write(message.content[0].text.value)
            # st.session_state.messages.append({"role": "user", "content": query})
            # st.session_state.messages.append({"role": "system", "content": m_text})
        else:
            st.session_state.new_conversation = False
            results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
            descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
            ntokens = len(descriptions.split(" "))
            print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
            prompt = utils.clustering_prompt if report_type == "clustered" else utils.default_prompt
            prompt_txt = prompt + """
            User query: {query}
            Company descriptions: {descriptions}
            """
            prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
            prompt = prompt_template.format(descriptions=descriptions, query=query)
            print(f"==============================\nPrompt:\n{prompt[:1000]}\n==============================\n")
            m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
            m_text  # bare expression: relies on Streamlit "magic" to display the report
            st.session_state.messages.append({"role": "user", "content": query})
            i = m_text.find("-----")
            i = 0 if i < 0 else i
            st.session_state.messages.append({"role": "system", "content": m_text[:i]})
            #render_history()
            # for message in st.session_state.messages:
            #     with st.chat_message(message["role"]):
            #         st.markdown(message["content"])
            # print(f"History: \n {st.session_state.messages}")
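
        # Render the ranked results: de-duplicated company cards, an optional map view,
        # and a dataframe built from the accumulated card fields.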
        sorted_results = sorted(results, key=lambda x: x['score'], reverse=True)
        names = []
        # list_html = """
        # <h2>Companies list</h2>
        # <div class="container-fluid">
        #     <div class="row align-items-start" style="padding-bottom:10px;">
        #         <div class="col-md-8 col-sm-8">
        #             <span>Company</span>
        #         </div>
        #         <div class="col-md-1 col-sm-1">
        #             <span>Country</span>
        #         </div>
        #         <div class="col-md-1 col-sm-1">
        #             <span>Customer Problem</span>
        #         </div>
        #         <div class="col-md-1 col-sm-1">
        #             <span>Business Model</span>
        #         </div>
        #         <div class="col-md-1 col-sm-1">
        #             Actions
        #         </div>
        #     </div>
        # """
        list_html = "<div class='container-fluid'>"
        locations = set()
        for r in sorted_results:
            company_name = r["name"]
            if company_name in names:
                continue
            else:
                names.append(company_name)
            description = r["description"]  #.replace(company_name, f"<mark>{company_name}</mark>")
            if description is None or len(description.strip()) < 10:
                continue
            score = round(r["score"], 4)
            data_type = r["metadata"]["type"] if "type" in r["metadata"] else ""
            region = r["metadata"]["region"]
            country = r["metadata"]["country"]
            company_id = r["metadata"]["company_id"]
            locations.add(country)
            list_html = list_html + card(company_id, company_name, description, score, data_type, region, country, r['data'], is_debug)
        list_html = list_html + '</div>'
        pins = country_geo[country_geo['name'].isin(locations)].loc[:, ['latitude', 'longitude']]
        if len(pins) > 0:
            with st.expander("Map view"):
                st.map(pins)
        #st.markdown(list_html, unsafe_allow_html=True)
        df = pd.DataFrame.from_dict(carddict, orient="columns")
        if len(df) > 0:
            df.index += 1
            with content_container:
                st.dataframe(df,
                             hide_index=False,
                             column_config={
                                 "name": st.column_config.TextColumn("Name"),
                                 "company_id": st.column_config.LinkColumn("Link"),
                                 "description": st.column_config.TextColumn("Description"),
                                 "country": st.column_config.TextColumn("Country", width="small"),
                                 "customer_problem": st.column_config.TextColumn("Customer problem"),
                                 "target_customer": st.column_config.TextColumn(label="Target customer", width="small"),
                                 "business_model": st.column_config.TextColumn(label="Business model")
                             },
                             use_container_width=True)
        st.session_state.last_user_query = query


def query_sent():
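    """Clear the query text input stored in session state."""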
    st.session_state.user_query = ""


def find_default_assistant_idx(assistants):
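    """Return the index of the default assistant in `assistants`, or 0 if it is not found."""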
    default_assistant_id = 'asst_8aSvGL075pmE1r8GAjjymu85'  # startup discovery 3 steps
    for idx, assistant in enumerate(assistants):
        if assistant.id == default_assistant_id:
            return idx
    return 0


def render_history():
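    """Render the chat history as a scrollable HTML block inside the history container."""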
    with st.session_state.history_container:
        s = f"""
        <div style='overflow: hidden; padding:10px 0px;'>
            <div id="chat_history" style='overflow-y: scroll;height: 200px;'>
        """
        for m in st.session_state.messages:
            #print(f"Printing message\t {m['role']}: {m['content']}")
            s = s + f"<div class='chat_message'><b>{m['role']}</b>: {m['content']}</div>"
        s = s + f"""</div>
        </div>
        <script>
            var el = document.getElementById("chat_history");
            el.scrollTop = el.scrollHeight;
        </script>
        """
        components.html(s, height=220)
        #st.markdown(s, unsafe_allow_html=True)
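

# Session-state defaults shared across Streamlit reruns.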
if 'submitted_query' not in st.session_state:
    st.session_state.submitted_query = ''
if 'messages' not in st.session_state:
    st.session_state.messages = []
if 'last_user_query' not in st.session_state:
    st.session_state.last_user_query = ''
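
# The main page flow is gated behind the shared password check.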
if utils.check_password():
    st.markdown("<script language='javascript'>console.log('scrolling');</script>", unsafe_allow_html=True)

    if st.sidebar.button("New Conversation") or "messages" not in st.session_state:
        st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
        st.session_state.new_conversation = True
        st.session_state.messages = []

    st.markdown("<h1 style='text-align: center; color: red; position: relative; top: -3rem;'>Raized.AI – Startups discovery demo</h1>", unsafe_allow_html=True)
    #st.write("Search for a company in free text. Describe the type of company you are looking for, the problem they solve and the solution they provide. You can also copy in the description of a similar company to kick off the search.")
    st.markdown("""
        <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
        """, unsafe_allow_html=True)

    with open("data/countries.json", "r") as f:
        countries = json.load(f)['countries']

    header = st.sidebar.markdown("Filters")
    #new_conversation = st.sidebar.button("New Conversation", key="new_conversation")
    countries_selectbox = st.sidebar.multiselect("Country", countries, default=[])
    all_regions = ('Africa', 'Europe', 'Asia & Pacific', 'North America', 'South/Latin America')
    region_selectbox = st.sidebar.multiselect("Region", all_regions, default=all_regions)
    all_bizmodels = ('B2B', 'B2C', 'eCommerce & Marketplace', 'Manufacturing', 'SaaS', 'Advertising', 'Commission', 'Subscription')
    bizmodel_selectbox = st.sidebar.multiselect("Business Model", all_bizmodels, default=all_bizmodels)

    st.markdown(
        '''
        <script>
        function like_company(company_id) {
            console.log("Company " + company_id + " Liked!");
        }
        function dislike_company(company_id) {
            console.log("Company " + company_id + " Disliked!");
        }
        </script>
        <style>
        .sidebar .sidebar-content {
            width: 375px;
        }
        </style>
        ''',
        unsafe_allow_html=True
    )

    #tab_search, tab_advanced = st.tabs(["Search", "Settings"])
    tab_search = st.container()
    with tab_search:
        #report_type = st.multiselect("Report Type", utils.get_prompts(), key="search_prompts_multiselect")
        st.session_state.history_container = st.container()
        with stylable_container(
            key="query_panel",
            css_styles="""
                .stTextInput {
                    position: fixed;
                    bottom: 0px;
                    background: white;
                    z-index: 1000;
                    padding-bottom: 2rem;
                    padding-left: 1rem;
                    padding-right: 1rem;
                    padding-top: 1rem;
                    border-top: 1px solid whitesmoke;
                    height: 8rem;
                    border-radius: 8px 8px 8px 8px;
                    box-shadow: 0 -3px 3px whitesmoke;
                }
                """,
        ):
            query = st.text_input(key="user_query",
                                  label="Enter your query",
                                  placeholder="Tell me what startups you are looking for",
                                  label_visibility="collapsed")
        #cluster = st.checkbox("Cluster the results", value=False, key="cluster")
        #prompt_new = st.button("New", on_click=_prompt(prompt_title, prompt))

    tab_advanced = st.sidebar.expander("Settings")
    with tab_advanced:
        #prompt_title = st.selectbox("Report Type", index=0, options=utils.get_prompts(), on_change=on_prompt_selected, key="advanced_prompts_select")
        #prompt_title_editable = st.text_input("Title", key="prompt_title_editable")
        report_type = st.selectbox(label="Response Type", options=["gemini", "assistant", "standard", "guided", "company_list", "clustered"], index=0)
        #assistant_id = st.text_input(label="Assistant ID", key="assistant_id", value="asst_NHoxEosVlemDY7y5TYg8ftku")  #value="asst_fkZtxo127nxKOCcwrwznuCs2")
        assistant_id = st.selectbox(label="OpenAI Assistant", options=[f"{a.id}|||{a.name}" for a in assistants], index=find_default_assistant_idx(assistants))
        #prompt_new = st.button("New", on_click=_prompt(prompt_title, prompt))
        #prompt_delete = st.button("Del", on_click=utils.del_prompt(prompt_title_editable))
        #prompt_save = st.button("Save", on_click=utils.save_prompt(prompt_title_editable, prompt))
        #scrape_boost = st.number_input('Web to API content ratio', value=1.)
        top_k = st.number_input('# Top Results', value=30)
        is_debug = st.checkbox("Debug output", value=False, key="debug")
        openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
        index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
        liked_companies = st.text_input(label="liked companies", key='liked_companies')
        disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
        default_prompt = st.text_area("Default Prompt", value=utils.default_prompt, height=400, key="advanced_default_prompt_content")
        clustering_prompt = st.text_area("Clustering Prompt", value=utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")
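
    # Create an Assistants API thread once per session, then dispatch the query.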
    if "assistant_thread" not in st.session_state:
        st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()

    if query != "" and ('new_conversation' not in st.session_state or not st.session_state.new_conversation):
        # if report_type == "standard":
        #     prompt = default_prompt
        # elif report_type == "clustered":
        #     prompt = clustering_prompt
        # elif report_type == "guided":
        #     prompt = "guided"
        # else:
        #     prompt = ""
        #oai.start_conversation()
        i = assistant_id.index("|||")
        st.session_state.assistant_id = assistant_id[:i]
        st.session_state.report_type = report_type
        st.session_state.top_k = top_k
        st.session_state.index_namespace = index_namespace
        st.session_state.region = region_selectbox
        st.session_state.country = countries_selectbox
        run_query(query, report_type, top_k, region_selectbox, countries_selectbox, is_debug, index_namespace, openai_model, default_prompt)
    else:
        st.session_state.new_conversation = False