Spaces:
Runtime error
Runtime error
| import chat.arxiv_bot.arxiv_bot_utils2 as utils | |
| import google.generativeai as genai | |
| import json | |
| import os | |
| from google.generativeai.types import content_types | |
| from collections.abc import Iterable | |
| from IPython import display | |
| from IPython.display import Markdown | |
# ----------------------- define instructions -----------------------
# System prompt passed to GenerativeModel(system_instruction=...) in init_model().
# This text is sent to the model verbatim, so any wording change alters the
# chatbot's behaviour.
system_instruction = (
    "You are a library chatbot that help people to find relevant articles about a topic, or find a specific article with given title and authors.\n"
    "Your job is to analyze the user question, generate enough parameters based on the user question and use the tools that are given to you.\n"
    "Also, after the function call is done, you must post-process the results in a more conversational form, providing some explanation about the paper based on its summary to avoid recitation.\n"
    "You must provide the link to its Arxiv pdf page."
)
# --------------------------- define tools --------------------------
# NOTE: these functions are handed to GenerativeModel via `tools=` (see
# init_model), and their docstrings are written as instructions to the model.
# The SDK presumably uses them as the tool descriptions, so treat the
# docstrings as runtime text and document maintainer details in comments.


# Tool: semantic search for articles related to a topic.
# Flow: query the Chroma index. If nothing matches, crawl arXiv for the
# keywords, add the new records to both Chroma and SQL, and query again. The
# matching paper ids are then resolved to SQL rows, and each row is rendered
# with its Chroma document text as the summary.
# Returns the formatted records, or "Information not found" when the crawler
# fails or no rows are found.
def search_for_relevant_article(keywords: list[str], topic_description: str) -> str:
    """This tool is used to search for articles from the database which is relevant to a topic, using a list of more than 3 keywords and a long sentence topic description.
If there is not enough 3 keywords from the question, the model must generate more keywords related to the topic.
If there is no description about the topic, the model must generate a description for the function call.
\nThe result is a string describe the records found from the database: 'Record no. - Title: <title>, Author: <authors>, Link: <link to the pdf file>, Summary: <summary of the article>'. There can be many records.
\nIf the result is 'Information not found' it means some error has occured, or the database has no relevant article"""
    print('Keywords: {}, description: {}'.format(keywords, topic_description))
    results = utils.ArxivChroma.query_relevant(keywords=keywords, query_texts=topic_description)
    metadatas = results['metadatas'][0]
    if not metadatas:
        # Nothing indexed for this topic yet: crawl arXiv and index the results.
        new_records = utils.crawl_arxiv(keyword_list=keywords, max_results=10)
        # The crawler signals failure by returning a string instead of records.
        if isinstance(new_records, str):
            return "Information not found"
        utils.ArxivChroma.add(new_records)
        utils.ArxivSQL.add(new_records)
        results = utils.ArxivChroma.query_relevant(keywords=keywords, query_texts=topic_description)
        metadatas = results['metadatas'][0]

    paper_ids = [meta['paper_id'] for meta in metadatas]
    paper_info = utils.ArxivSQL.query_id(paper_ids)
    if not paper_info:
        return "Information not found"

    # SQL row columns used here: 0 = id, 2 = title, 3 = authors, 6 = link.
    formatted = []
    for record_no, row in enumerate(paper_info, start=1):
        documents = utils.ArxivChroma.query_exact(row[0])["documents"]
        summary = "".join(doc + " " for doc in documents)
        formatted.append(
            "Record no.{} - Title: {}, Author: {}, Link: {}, Summary:{}".format(
                record_no, row[2], row[3], row[6], summary
            )
        )
    return "".join(formatted)
# Tool: look up a specific article by title and authors.
# Flow: query SQL. On a miss, crawl arXiv for that exact paper, add it to both
# Chroma and SQL, and query SQL again. Each row is rendered with its Chroma
# document text as the summary.
# Returns the formatted records, or "Information not found" when the crawler
# fails or no rows are found (an empty or None query result is treated as a
# miss).
# The docstring below is model-facing tool text; keep it in sync deliberately.
def search_for_specific_article(title: str, authors: list[str]) -> str:
    """This tool is used to search for a specific article from the database, with its name and authors given.
\nThe result is a string describe the records found from the database: 'Record no. - Title: <title>, Author: <authors>, Link: <link to the pdf file>, Summary: <summary of the article>'. There can be many records.
\nIf the result is 'Information not found' it means some error has occured, or the database has no relevant article"""
    print('Title: {}, authors: {}'.format(title, authors))
    paper_info = utils.ArxivSQL.query(title=title, author=authors)
    if not paper_info:
        new_records = utils.crawl_exact_paper(title=title, author=authors)
        # The crawler signals failure by returning a string instead of records.
        if isinstance(new_records, str):
            return "Information not found"
        utils.ArxivChroma.add(new_records)
        utils.ArxivSQL.add(new_records)
        paper_info = utils.ArxivSQL.query(title=title, author=authors)

    if not paper_info:
        return "Information not found"

    # SQL row columns used here: 0 = id, 2 = title, 3 = authors, 6 = link.
    formatted = []
    for record_no, row in enumerate(paper_info, start=1):
        documents = utils.ArxivChroma.query_exact(row[0])["documents"]
        summary = "".join(doc + " " for doc in documents)
        formatted.append(
            "Record no.{} - Title: {}, Author: {}, Link: {}, Summary:{}".format(
                record_no, row[2], row[3], row[6], summary
            )
        )
    return "".join(formatted)
def answer_others_questions(question: str) -> str:
    """This tool is the default option for other questions that are not related to article or paper request. The model will response the question with its own answer."""
    # Fallback tool for anything that is not an article lookup. It does no work
    # and hands the question back unchanged, leaving the model to answer it.
    echoed_question = question
    return echoed_question
# Functions exposed to the model as callable tools (passed to GenerativeModel).
tools = [search_for_relevant_article, search_for_specific_article, answer_others_questions]
# Tool names, derived from `tools` so the two lists can never drift apart.
tools_name = [tool.__name__ for tool in tools]
# load key, prepare config ------------------------
def _load_api_key(path="apikey.txt", env_var="GOOGLE_API_KEY"):
    """Return the Gemini API key.

    The key is the first line of ``path``, with surrounding whitespace and the
    trailing newline stripped, since a newline would corrupt the key. If the
    file does not exist, the ``env_var`` environment variable is used instead.

    Raises:
        FileNotFoundError: if the file is missing and ``env_var`` is unset or
            empty.
    """
    try:
        with open(path, "r", encoding="utf-8") as key_file:
            return key_file.readline().strip()
    except FileNotFoundError:
        env_key = os.environ.get(env_var)
        if env_key:
            return env_key.strip()
        raise


key = _load_api_key()
genai.configure(api_key=key)
# Sampling and output settings used for every model built by init_model().
generation_config = dict(
    temperature=1,
    top_p=1,
    top_k=0,
    max_output_tokens=2048,
    response_mime_type="text/plain",
)
# Safety settings for the model: blocking is disabled (BLOCK_NONE) for each of
# the Gemini harm categories listed below.
# HARM_CATEGORY_DANGEROUS is deliberately absent. It is documented as a
# PaLM-model category, not a Gemini one, and HARM_CATEGORY_DANGEROUS_CONTENT
# is the Gemini counterpart.
safety_settings = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]
# Build a tool_config for a function-calling mode ('none', 'any' or 'auto').
def tool_config_from_mode(mode: str, fns: Iterable[str] = ()):
    """Create a tool config with the specified function calling mode.

    Args:
        mode: function-calling mode, e.g. "none", "any" or "auto".
        fns: function names to place in ``allowed_function_names`` (default:
            none). Any iterable is accepted; it is copied into a list so
            one-shot iterables such as generators work too.

    Returns:
        The tool config produced by ``content_types.to_tool_config``.
    """
    return content_types.to_tool_config(
        {"function_calling_config": {"mode": mode, "allowed_function_names": list(fns)}}
    )
def init_model(mode="auto", model_name="gemini-1.5-pro-latest"):
    """Build a Gemini model wired to this module's tools and start a chat.

    Every socket session is meant to hold its own model and ChatSession, so
    call this when a socket session is initialised. The chat is started here
    with automatic function calling enabled.

    Args:
        mode: function-calling mode passed to ``tool_config_from_mode``
            ("none", "any" or "auto").
        model_name: Gemini model to instantiate. Defaults to
            "gemini-1.5-pro-latest".

    Returns:
        ``(model, chat_instance)``: the ``GenerativeModel`` and its
        ``ChatSession``.
    """
    model = genai.GenerativeModel(model_name=model_name,
                                  safety_settings=safety_settings,
                                  generation_config=generation_config,
                                  tools=tools,
                                  tool_config=tool_config_from_mode(mode),
                                  system_instruction=system_instruction)
    chat_instance = model.start_chat(enable_automatic_function_calling=True)
    return model, chat_instance
def full_chain_history_question(user_input, chat_instance: genai.ChatSession, mode="auto"):
    """Send one user turn through ``chat_instance`` and return the outcome.

    Args:
        user_input: the message to send.
        chat_instance: the session's ChatSession (see ``init_model``).
        mode: function-calling mode for this call ("none", "any" or "auto").

    Returns:
        ``(text, history)``. On success, ``text`` is the reply text. If
        anything in the call raises, the exception is printed and ``text`` is
        an error message instead. ``history`` is the session's history in
        both cases.
    """
    try:
        outgoing_config = tool_config_from_mode(mode)
        reply_text = chat_instance.send_message(user_input, tool_config=outgoing_config).text
        return reply_text, chat_instance.history
    except Exception as err:
        # Boundary for a single chat turn: report the failure as text rather
        # than propagating it to the caller.
        print(err)
        failure_message = f'Error occured during call: {err}'
        return failure_message, chat_instance.history
def print_history(history):
    """Debug helper that logs a chat history to stdout.

    For each message, prints its role and its first part converted with the
    part type's ``to_dict``, followed by an 80-character separator line.
    """
    separator = '-' * 80
    for message in history:
        first_part = message.parts[0]
        part_as_dict = type(first_part).to_dict(first_part)
        print(message.role, "->", part_as_dict)
        print(separator)
# Connect the Chroma store first, then the SQL store; the tool functions query
# both. This runs at import time.
for _store in (utils.ArxivChroma, utils.ArxivSQL):
    _store.connect()
del _store