Spaces:
Build error
Build error
| import openai | |
| import streamlit_scrollable_textbox as stx | |
| import pinecone | |
| import streamlit as st | |
| st.set_page_config(layout="wide") # isort: split | |
| from utils.entity_extraction import ( | |
| clean_entities, | |
| extract_quarter_year, | |
| extract_ticker_spacy, | |
| format_entities_flan_alpaca, | |
| generate_alpaca_ner_prompt, | |
| ) | |
| from utils.models import ( | |
| generate_entities_flan_alpaca_checkpoint, | |
| generate_entities_flan_alpaca_inference_api, | |
| generate_text_flan_t5, | |
| get_data, | |
| get_flan_alpaca_xl_model, | |
| get_flan_t5_model, | |
| get_mpnet_embedding_model, | |
| get_sgpt_embedding_model, | |
| get_spacy_model, | |
| get_splade_sparse_embedding_model, | |
| get_t5_model, | |
| gpt_model, | |
| save_key, | |
| ) | |
| from utils.prompts import ( | |
| generate_flant5_prompt_instruct_chunk_context, | |
| generate_flant5_prompt_instruct_chunk_context_single, | |
| generate_flant5_prompt_instruct_complete_context, | |
| generate_flant5_prompt_summ_chunk_context, | |
| generate_flant5_prompt_summ_chunk_context_single, | |
| generate_gpt_j_two_shot_prompt_1, | |
| generate_gpt_j_two_shot_prompt_2, | |
| generate_gpt_prompt, | |
| generate_gpt_prompt_2, | |
| get_context_list_prompt, | |
| ) | |
| from utils.retriever import ( | |
| format_query, | |
| query_pinecone, | |
| query_pinecone_sparse, | |
| sentence_id_combine, | |
| text_lookup, | |
| ) | |
| from utils.transcript_retrieval import retrieve_transcript | |
| from utils.vector_index import ( | |
| create_dense_embeddings, | |
| create_sparse_embeddings, | |
| hybrid_score_norm, | |
| ) | |
# ---------------------------------------------------------------------------
# Page header, layout and entity extraction from the user's question.
# ---------------------------------------------------------------------------
app_description = "The app uses the quarterly earnings call transcripts for 10 companies (Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, Micron) for the years 2016 to 2020."

st.title("Question Answering on Earnings Call Transcripts")
st.write(app_description)

# Two equal-weight columns: col1 holds the question, filters and retrieved
# text; col2 holds the model prompt form and the generated answer.
col1, col2 = st.columns([3, 3], gap="medium")

with st.sidebar:
    ner_choice = st.selectbox("Select NER Model", ["Spacy", "Alpaca"])

# The spaCy model is only requested when it is the selected NER backend.
if ner_choice == "Spacy":
    ner_model = get_spacy_model()

with col1:
    st.subheader("Question")
    query_text = st.text_area(
        "Input Query",
        value="What was discussed regarding Wearables revenue performance?",
    )

# Pull company / quarter / year mentions out of the question text using the
# selected NER backend.
if ner_choice != "Alpaca":
    company_ent = extract_ticker_spacy(query_text, ner_model)
    quarter_ent, year_ent = extract_quarter_year(query_text)
else:
    ner_prompt = generate_alpaca_ner_prompt(query_text)
    entity_text = generate_entities_flan_alpaca_inference_api(ner_prompt)
    (
        company_ent,
        quarter_ent,
        year_ent,
    ) = format_entities_flan_alpaca(entity_text)

# The three results are used further down as the default positions of the
# Company / Quarter / Year select boxes.
ticker_index, quarter_index, year_index = clean_entities(
    company_ent,
    quarter_ent,
    year_ent,
)
# ---------------------------------------------------------------------------
# Metadata filters (year, quarter, speaker, company) shown under the question.
# ---------------------------------------------------------------------------
years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
quarter_choice = ["Q1", "Q2", "Q3", "Q4", "All"]
ticker_choice = [
    "AAPL",
    "CSCO",
    "MSFT",
    "ASML",
    "NVDA",
    "GOOGL",
    "MU",
    "INTC",
    "AMZN",
    "AMD",
]

# Hardcoding the defaults for a question without metadata: when the query is
# still the pre-filled example, every filter starts at its first option
# instead of at the index derived from the extracted entities.
is_default_query = (
    query_text == "What was discussed regarding Wearables revenue performance?"
)

with col1:
    year = st.selectbox(
        "Year",
        years_choice,
        index=0 if is_default_query else year_index,
    )
    quarter = st.selectbox(
        "Quarter",
        quarter_choice,
        index=0 if is_default_query else quarter_index,
    )
    participant_type = st.selectbox("Speaker", ["Company Speaker", "Analyst"])
    ticker = st.selectbox(
        "Company",
        ticker_choice,
        index=0 if is_default_query else ticker_index,
    )
# ---------------------------------------------------------------------------
# Sidebar options and retriever (Pinecone index + embedding model) setup.
# ---------------------------------------------------------------------------
with st.sidebar:
    st.subheader("Select Options:")
    num_results = int(
        st.number_input("Number of Results to query", 1, 15, value=5)
    )

    # Choose encoder model
    encoder_models_choice = ["MPNET", "SGPT", "Hybrid MPNET - SPLADE"]
    encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)

    # Choose decoder model
    decoder_models_choice = [
        "GPT3 - (text-davinci-003)",
        "T5",
        "FLAN-T5",
        "GPT-J",
    ]
    decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)

# Per-encoder connection settings:
#   (st.secrets key, Pinecone environment, index name, dense model loader)
encoder_settings = {
    "MPNET": (
        "pinecone_mpnet",
        "us-east1-gcp",
        "week2-all-mpnet-base",
        get_mpnet_embedding_model,
    ),
    "SGPT": (
        "pinecone_sgpt",
        "us-east1-gcp",
        "week2-sgpt-125m",
        get_sgpt_embedding_model,
    ),
    "Hybrid MPNET - SPLADE": (
        "pinecone_hybrid_splade_mpnet",
        "us-central1-gcp",
        "splade-mpnet",
        get_mpnet_embedding_model,
    ),
}

selected_settings = encoder_settings.get(encoder_model)
if selected_settings is not None:
    (
        secret_name,
        pinecone_environment,
        pinecone_index_name,
        load_dense_model,
    ) = selected_settings
    # Connect to pinecone environment (only the secret for the selected
    # encoder is read).
    pinecone.init(
        api_key=st.secrets[secret_name],
        environment=pinecone_environment,
    )
    pinecone_index = pinecone.Index(pinecone_index_name)
    retriever_model = load_dense_model()

# The hybrid encoder additionally needs the SPLADE sparse model + tokenizer.
if encoder_model == "Hybrid MPNET - SPLADE":
    (
        sparse_retriever_model,
        sparse_retriever_tokenizer,
    ) = get_splade_sparse_embedding_model()

with st.sidebar:
    window = int(st.number_input("Sentence Window Size", 0, 10, value=1))
    threshold = float(
        st.number_input(
            label="Similarity Score Threshold",
            step=0.05,
            format="%.2f",
            value=0.25,
        )
    )
# ---------------------------------------------------------------------------
# Retrieval: embed the question, query Pinecone, build the context passages.
# ---------------------------------------------------------------------------
data = get_data()

# Every encoder option needs the dense query embedding.
dense_query_embedding = create_dense_embeddings(query_text, retriever_model)

# FIX: this branch used to test for "Hybrid SGPT - SPLADE", a label that is
# not offered in the encoder select box, so the sparse-dense query path was
# unreachable and the SPLADE model loaded for the hybrid encoder was never
# used. The label now matches the "Hybrid MPNET - SPLADE" option.
if encoder_model == "Hybrid MPNET - SPLADE":
    sparse_query_embedding = create_sparse_embeddings(
        query_text, sparse_retriever_model, sparse_retriever_tokenizer
    )
    # NOTE(review): the third argument (0) is presumably the dense/sparse
    # weighting factor of hybrid_score_norm — confirm the intended value.
    dense_query_embedding, sparse_query_embedding = hybrid_score_norm(
        dense_query_embedding, sparse_query_embedding, 0
    )
    query_results = query_pinecone_sparse(
        dense_query_embedding,
        sparse_query_embedding,
        num_results,
        pinecone_index,
        year,
        quarter,
        ticker,
        participant_type,
        threshold,
    )
else:
    query_results = query_pinecone(
        dense_query_embedding,
        num_results,
        pinecone_index,
        year,
        quarter,
        ticker,
        participant_type,
        threshold,
    )

# Up to a threshold of 0.90 the hits are combined with `window` as the `lag`
# argument of sentence_id_combine; above that the raw query results are
# formatted directly.
if threshold <= 0.90:
    context_list = sentence_id_combine(data, query_results, lag=window)
else:
    context_list = format_query(query_results)
# ---------------------------------------------------------------------------
# Answer generation: build the prompt for the selected decoder, show it in an
# editable form in col2 and, on submit, run the model and render the answer.
#
# NOTE(review): the source this was recovered from had lost its indentation.
# The "if submitted:" sections are placed inside col2 directly below the form;
# confirm this matches the intended layout.
# ---------------------------------------------------------------------------
if decoder_model == "GPT3 - (text-davinci-003)":
    prompt = generate_gpt_prompt(query_text, context_list)
    with col2:
        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=270
            )
            openai_key = st.text_input(
                "Enter OpenAI key",
                value="",
                type="password",
            )
            submitted = st.form_submit_button("Submit")
        if submitted:
            api_key = save_key(openai_key)
            openai.api_key = api_key
            generated_text = gpt_model(edited_prompt)
            st.subheader("Answer:")
            st.write(generated_text)

elif decoder_model == "T5":
    prompt = generate_flant5_prompt_instruct_complete_context(
        query_text, context_list
    )
    t5_pipeline = get_t5_model()
    output_text = []
    with col2:
        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=270
            )
            # The context passages are re-derived from the (possibly edited)
            # prompt; this also replaces the list shown under
            # "See Retrieved Text".
            context_list = get_context_list_prompt(edited_prompt)
            submitted = st.form_submit_button("Submit")
        if submitted:
            # One summary per context passage.
            output_text = [
                t5_pipeline(context_text)[0]["summary_text"]
                for context_text in context_list
            ]
            st.subheader("Answer:")
            for text in output_text:
                st.markdown(f"- {text}")

elif decoder_model == "FLAN-T5":
    flan_t5_model, flan_t5_tokenizer = get_flan_t5_model()
    output_text = []
    # Per-passage prompt builders for the two chunkwise modes.
    chunk_prompt_builders = {
        "Chunkwise QA": generate_flant5_prompt_instruct_chunk_context_single,
        "Chunkwise Summarize": generate_flant5_prompt_summ_chunk_context_single,
    }
    with col2:
        prompt_type = st.selectbox(
            "Select prompt type",
            ["Complete Text QA", "Chunkwise QA", "Chunkwise Summarize"],
        )
        if prompt_type == "Complete Text QA":
            prompt = generate_flant5_prompt_instruct_complete_context(
                query_text, context_list
            )
        elif prompt_type == "Chunkwise QA":
            st.write("The following prompt is not editable.")
            prompt = generate_flant5_prompt_instruct_chunk_context(
                query_text, context_list
            )
        elif prompt_type == "Chunkwise Summarize":
            st.write("The following prompt is not editable.")
            prompt = generate_flant5_prompt_summ_chunk_context(
                query_text, context_list
            )
        else:
            prompt = ""
        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=270
            )
            submitted = st.form_submit_button("Submit")
        if submitted:
            if prompt_type == "Complete Text QA":
                # FIX: send the prompt as edited in the form. The generated
                # `prompt` was passed before, so edits made in this (the only
                # editable) mode were silently discarded.
                output_text_string = generate_text_flan_t5(
                    flan_t5_model, flan_t5_tokenizer, edited_prompt
                )
                st.subheader("Answer:")
                st.write(output_text_string)
            elif prompt_type in chunk_prompt_builders:
                # Chunkwise modes rebuild one prompt per passage, which is
                # why the prompt shown in the form is not editable.
                build_chunk_prompt = chunk_prompt_builders[prompt_type]
                output_text = [
                    generate_text_flan_t5(
                        flan_t5_model,
                        flan_t5_tokenizer,
                        build_chunk_prompt(query_text, context_text),
                    )
                    for context_text in context_list
                ]
                st.subheader("Answer:")
                for text in output_text:
                    # Outputs containing the "(iii)" marker are not shown.
                    if "(iii)" not in text:
                        st.markdown(f"- {text}")

# FIX: was a separate `if`; GPT-J is one more mutually exclusive decoder
# option, so it belongs to the same chain.
elif decoder_model == "GPT-J":
    # Only NVDA / INTC / AMZN use the second two-shot template; every other
    # ticker uses the first one.
    if ticker in ["NVDA", "INTC", "AMZN"]:
        prompt = generate_gpt_j_two_shot_prompt_2(query_text, context_list)
    else:
        prompt = generate_gpt_j_two_shot_prompt_1(query_text, context_list)
    with col2:
        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=270
            )
            st.write(
                "The app currently just shows the prompt. The app does not load the model due to memory limitations."
            )
            submitted = st.form_submit_button("Submit")
# ---------------------------------------------------------------------------
# Supporting material in col1: retrieved passages and the full transcript.
# ---------------------------------------------------------------------------
with col1, st.expander("See Retrieved Text"):
    st.subheader("Retrieved Text:")
    for passage in context_list:
        # NOTE(review): passages are interpolated into raw HTML and rendered
        # with unsafe_allow_html; confirm they can never carry markup that
        # should be escaped.
        passage_html = f"<ul><li><p>{passage}</p></li></ul>"
        st.write(passage_html, unsafe_allow_html=True)

# Transcript matching the currently selected year / quarter / company.
file_text = retrieve_transcript(data, year, quarter, ticker)

with col1, st.expander("See Transcript"):
    st.subheader("Earnings Call Transcript:")
    stx.scrollableTextbox(
        file_text,
        height=700,
        border=False,
        fontFamily="Helvetica",
    )