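"""
Streamlit app: question answering over arXiv astro-ph.GA abstracts with
retrieval-augmented generation. Abstracts are embedded with Azure OpenAI
(text-embedding-ada-002), searched with FAISS, and the retrieved context is
passed to GPT-4 through a LangChain RAG chain.
"""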
import os
import json
import pickle
import urllib.request

import faiss
import feedparser
import matplotlib.pyplot as plt
import numpy as np
import requests
import streamlit as st
from summa import summarizer

from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
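# Azure OpenAI configuration: credentials come from Streamlit secrets; one
# deployment serves the embeddings, another serves the GPT-4 chat model.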
os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["AZURE_ENDPOINT"] = st.secrets["endpoint1"]
os.environ["OPENAI_API_KEY"] = st.secrets["key1"]
os.environ["OPENAI_API_VERSION"] = "2023-05-15"

embeddings = AzureOpenAIEmbeddings(
    deployment="embedding",
    model="text-embedding-ada-002",
    azure_endpoint=st.secrets["endpoint1"],
)

llm = AzureChatOpenAI(
    deployment_name="gpt4_small",
    openai_api_version="2023-12-01-preview",
    azure_endpoint=st.secrets["endpoint2"],
    openai_api_key=st.secrets["key2"],
    openai_api_type="azure",
    temperature=0.0,
)
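# Data loading: cached astro-ph.GA feeds and their ada-002 abstract embeddings
# are read from local pickle files (the commented links point at remote copies).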
def get_feeds_data(url):
    # data = cp.load(urlopen(url))
    with open(url, "rb") as fp:
        data = pickle.load(fp)
    st.sidebar.success("Loaded data")
    return data

# feeds_link = "https://drive.google.com/uc?export=download&id=1-IPk1voyUM9VqnghwyVrM1dY6rFnn1S_"
# embed_link = "https://dl.dropboxusercontent.com/s/ob2betm29qrtb8v/astro_ph_ga_feeds_ada_embedding_18-Apr-2023.pkl?dl=0"
dateval = "27-Jun-2023"
feeds_link = "local_files/astro_ph_ga_feeds_upto_" + dateval + ".pkl"
embed_link = "local_files/astro_ph_ga_feeds_ada_embedding_" + dateval + ".pkl"
gal_feeds = get_feeds_data(feeds_link)
arxiv_ada_embeddings = get_feeds_data(embed_link)
def get_embedding_data(url):
    # data = cp.load(urlopen(url))
    with open(url, "rb") as fp:
        data = pickle.load(fp)
    st.sidebar.success("Loaded embedding data")
    return data

# url = "https://drive.google.com/uc?export=download&id=1133tynMwsfdR1wxbkFLhbES3FwDWTPjP"
url = "local_files/astro_ph_ga_embedding_" + dateval + ".pkl"
e2d = get_embedding_data(url)
# e2d, _, _, _, _ = get_embedding_data(url)
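# Flatten the cached feed chunks into parallel lists: abstract text, titles,
# arXiv ids, links, and author lists (one entry per paper).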
num_chunks = len(gal_feeds)
all_text, all_titles, all_arxivid, all_links, all_authors = [], [], [], [], []

for nc in range(num_chunks):
    for i in range(len(gal_feeds[nc].entries)):
        text = gal_feeds[nc].entries[i].summary
        text = text.replace('\n', ' ')
        text = text.replace('\\', '')
        all_text.append(text)
        all_titles.append(gal_feeds[nc].entries[i].title)
        all_arxivid.append(gal_feeds[nc].entries[i].id.split('/')[-1][0:-2])
        all_links.append(gal_feeds[nc].entries[i].links[1].href)
        all_authors.append(gal_feeds[nc].entries[i].authors)
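# Build an exhaustive (flat) L2 FAISS index over the abstract embeddings for
# nearest-neighbour search at query time.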
d = arxiv_ada_embeddings.shape[1]   # embedding dimension
nb = arxiv_ada_embeddings.shape[0]  # database size
xb = arxiv_ada_embeddings.astype('float32')
index = faiss.IndexFlatL2(d)
index.add(xb)
def run_simple_query(search_query='all:sed+fitting', max_results=10, start=0,
                     sort_by='lastUpdatedDate', sort_order='descending'):
    """
    Query the arXiv API and return search results for a particular query.

    Parameters
    ----------
    search_query: str
        Query term. Use prefixes ti, au, abs, co, jr, cat, m, id, all as applicable.
    max_results: int, default = 10
        Number of results to return. Numbers > 1000 generally lead to timeouts.
    start: int, default = 0
        Start index for results reported. Use this if you're interested in running chunks.

    Returns
    -------
    feed: dict
        Object containing the requested results parsed with feedparser.

    Notes
    -----
    TODO: add functionality for chunk parsing, as well as storage and retrieval.
    """
    base_url = 'http://export.arxiv.org/api/query?'
    query = 'search_query=%s&start=%i&max_results=%i&sortBy=%s&sortOrder=%s' % (
        search_query, start, max_results, sort_by, sort_order)
    response = urllib.request.urlopen(base_url + query).read()
    feed = feedparser.parse(response)
    return feed
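# Utility: substring-match an author name against every paper's author list
# and return the matching document ids.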
def find_papers_by_author(auth_name):
    doc_ids = []
    for doc_id in range(len(all_authors)):
        for auth_id in range(len(all_authors[doc_id])):
            if auth_name.lower() in all_authors[doc_id][auth_id]['name'].lower():
                print('Doc ID: ', doc_id, ' | arXiv: ', all_arxivid[doc_id],
                      '| ', all_titles[doc_id],
                      ' | Author entry: ', all_authors[doc_id][auth_id]['name'])
                doc_ids.append(doc_id)
    return doc_ids
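# Nearest-neighbour search: return the indices and L2 distances of the
# `nindex` abstracts closest to the query embedding.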
def faiss_based_indices(input_vector, nindex=10):
    xq = input_vector.reshape(1, -1).astype('float32')
    D, I = index.search(xq, nindex)
    return I[0], D[0]
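# List the papers most similar to a query, which may be a document id, an
# arXiv id (resolved live via the arXiv API), or free-form keywords. Returns a
# markdown summary string plus the matched abstracts, file headers, and ids.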
def list_similar_papers_v2(model_data,
                           doc_id=[], input_type='doc_id',
                           show_authors=False, show_summary=False,
                           return_n=10):
    arxiv_ada_embeddings, embeddings, all_titles, all_abstracts, all_authors = model_data

    if input_type == 'doc_id':
        print('Doc ID: ', doc_id, ', title: ', all_titles[doc_id])
        # inferred_vector = model.infer_vector(train_corpus[doc_id].words)
        inferred_vector = arxiv_ada_embeddings[doc_id, :]
        start_range = 1  # skip the query paper itself in the results
    elif input_type == 'arxiv_id':
        print('ArXiv id: ', doc_id)
        arxiv_query_feed = run_simple_query(search_query='id:' + str(doc_id))
        if len(arxiv_query_feed.entries) == 0:
            print('error: arxiv id not found.')
            return
        else:
            print('Title: ' + arxiv_query_feed.entries[0].title)
            inferred_vector = np.array(embeddings.embed_query(arxiv_query_feed.entries[0].summary))
        start_range = 0
    elif input_type == 'keywords':
        inferred_vector = np.array(embeddings.embed_query(doc_id))
        start_range = 0
    else:
        print('unrecognized input type.')
        return

    sims, dists = faiss_based_indices(inferred_vector, return_n + 2)
    textstr = ''
    abstracts_relevant = []
    fhdrs = []

    for i in range(start_range, start_range + return_n):
        abstracts_relevant.append(all_text[sims[i]])
        # File header encodes doc id, first author's surname + two-digit year,
        # and the full arXiv id, e.g. "123_Smith23_2306.01234".
        fhdr = (str(sims[i]) + '_' + all_authors[sims[i]][0]['name'].split()[-1]
                + all_arxivid[sims[i]][0:2] + '_' + all_arxivid[sims[i]])
        fhdrs.append(fhdr)
        textstr = textstr + str(i + 1) + '. **' + all_titles[sims[i]] + '** (Distance: %.2f' % dists[i] + ') \n'
        textstr = textstr + '**ArXiv:** [' + all_arxivid[sims[i]] + '](https://arxiv.org/abs/' + all_arxivid[sims[i]] + ') \n'
        if show_authors:
            textstr = textstr + '**Authors:** '
            temp = all_authors[sims[i]]
            for ak in range(len(temp)):
                if ak < len(temp) - 1:
                    textstr = textstr + temp[ak].name + ', '
                else:
                    textstr = textstr + temp[ak].name + ' \n'
        if show_summary:
            textstr = textstr + '**Summary:** '
            text = all_text[sims[i]]
            text = text.replace('\n', ' ')
            textstr = textstr + summarizer.summarize(text) + ' \n'
        if show_authors or show_summary:
            textstr = textstr + ' '
        textstr = textstr + ' \n'
    return textstr, abstracts_relevant, fhdrs, sims
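# Standalone REST helper for chat completions; it is not called by the app
# below and expects an API_ENDPOINT URL to be defined before use.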
def generate_chat_completion(messages, model="gpt-4", temperature=1, max_tokens=None):
    headers = {
        "Content-Type": "application/json",
        # Reads the key set via os.environ above.
        "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",
    }
    data = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
    }
    if max_tokens is not None:
        data["max_tokens"] = max_tokens
    # API_ENDPOINT is not defined in this script; set it to a chat-completions
    # URL before calling this helper.
    response = requests.post(API_ENDPOINT, headers=headers, data=json.dumps(data))
    if response.status_code == 200:
        return response.json()["choices"][0]["message"]["content"]
    else:
        raise Exception(f"Error {response.status_code}: {response.text}")
model_data = [arxiv_ada_embeddings, embeddings, all_titles, all_text, all_authors]

def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)
def get_textstr(i, show_authors=False, show_summary=False):
    textstr = '**' + all_titles[i] + '** \n'
    textstr = textstr + '**ArXiv:** [' + all_arxivid[i] + '](https://arxiv.org/abs/' + all_arxivid[i] + ') \n'
    if show_authors:
        textstr = textstr + '**Authors:** '
        temp = all_authors[i]
        for ak in range(len(temp)):
            if ak < len(temp) - 1:
                textstr = textstr + temp[ak].name + ', '
            else:
                textstr = textstr + temp[ak].name + ' \n'
    if show_summary:
        textstr = textstr + '**Summary:** '
        text = all_text[i]
        text = text.replace('\n', ' ')
        textstr = textstr + summarizer.summarize(text) + ' \n'
    if show_authors or show_summary:
        textstr = textstr + ' '
    textstr = textstr + ' \n'
    return textstr
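# Main RAG pipeline: retrieve similar abstracts with FAISS, re-index them in a
# Chroma vectorstore, answer the question with GPT-4 via the LangChain chain,
# and display the answer, its sources, and a 2D embedding map in Streamlit.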
def run_rag(query, return_n=10, show_authors=True, show_summary=True):
    # sims is the markdown list of similar papers; simids are their faiss indices.
    sims, absts, fhdrs, simids = list_similar_papers_v2(model_data,
                                                        doc_id=query,
                                                        input_type='keywords',
                                                        show_authors=show_authors,
                                                        show_summary=show_summary,
                                                        return_n=return_n)

    # Write each retrieved abstract to its own text file so it can be loaded
    # as a LangChain document.
    os.makedirs('absts', exist_ok=True)
    loaders = []
    for i in range(len(absts)):
        with open("absts/" + fhdrs[i] + ".txt", "w") as text_file:
            text_file.write(absts[i])
        loaders.append(TextLoader("absts/" + fhdrs[i] + ".txt"))

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
    splits = text_splitter.split_documents([loader.load()[0] for loader in loaders])
    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
    retriever = vectorstore.as_retriever()
    template = """You are an assistant with expertise in astrophysics for question-answering tasks.
Use the following pieces of retrieved context from the literature to answer the question.
If you don't know the answer, just say that you don't know.
Use six sentences maximum and keep the answer concise.
{context}
Question: {question}
Answer:"""
    custom_rag_prompt = PromptTemplate.from_template(template)

    rag_chain_from_docs = (
        RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
        | custom_rag_prompt
        | llm
        | StrOutputParser()
    )
    rag_chain_with_source = RunnableParallel(
        {"context": retriever, "question": RunnablePassthrough()}
    ).assign(answer=rag_chain_from_docs)
    rag_answer = rag_chain_with_source.invoke(query)
    st.markdown('### User query: ' + query)
    st.markdown(rag_answer['answer'])

    st.markdown('#### Primary sources: \n')
    srcnames = []
    for i in range(len(rag_answer['context'])):
        srcnames.append(rag_answer['context'][i].metadata['source'])
    srcnames = np.unique(srcnames)
    srcindices = []
    for i in range(len(srcnames)):
        # Filenames look like "absts/<docid>_<Author><yy>_<arxivid>.txt"; pull
        # out the doc id and expand the two-digit year to "Author et al. YYYY".
        temp = srcnames[i].split('_')[1]
        srcindices.append(int(srcnames[i].split('_')[0].split('/')[1]))
        if int(temp[-2:]) < 40:
            temp = temp[0:-2] + ' et al. 20' + temp[-2:]
        else:
            temp = temp[0:-2] + ' et al. 19' + temp[-2:]
        temp = '[' + temp + '](' + all_links[int(srcnames[i].split('_')[0].split('/')[1])] + ')'
        st.markdown(temp)
    abs_indices = np.array(srcindices)

    fig = plt.figure(figsize=(9, 9))
    plt.scatter(e2d[:, 0], e2d[:, 1], s=2)
    plt.scatter(e2d[simids, 0], e2d[simids, 1], s=30)
    plt.scatter(e2d[abs_indices, 0], e2d[abs_indices, 1], s=100, color='k', marker='d')
    plt.title('localization for question: ' + query)
    st.pyplot(fig)

    st.markdown('\n #### List of relevant papers:')
    st.markdown(sims)
    return rag_answer
st.title('ArXiv-based question answering')
st.markdown('[Includes papers up to: `' + dateval + '`]')
st.markdown('Concise answers to questions using arXiv abstracts + GPT-4. You might need to wait a few seconds for the GPT-4 query to return an answer (check the top right corner to see if it is still running).')
st.markdown('The answer is followed by the primary source(s) it drew on, a plot showing which part of the astro-ph.GA manifold the answer came from (tightly clustered points generally indicate high-quality/consensus answers), and the list of relevant papers the RAG pipeline used to compose the answer.')
st.markdown('If this does not satisfactorily answer your question or rambles too much, you can also try the older `qa_sources_v1` page.')

query = st.text_input('Your question here:',
                      value="What causes galaxy quenching at high redshifts?")
return_n = st.slider('How many papers should I show?', 1, 30, 10)
rag_answer = run_rag(query, return_n=return_n)