Spaces:
Running
Running
| import logging | |
| import os | |
| from typing import Optional | |
| import time | |
| import gradio as gr | |
| import pandas as pd | |
| from buster.completers import Completion | |
| from gradio.themes.utils import ( | |
| colors, | |
| fonts, | |
| get_matching_version, | |
| get_theme_assets, | |
| sizes, | |
| ) | |
| import cfg | |
| from cfg import setup_buster | |
# Maximum number of concurrent workers for the Gradio queue (env-overridable).
CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", 64))

# Single source of truth pairing each dropdown label with its internal source id.
# AVAILABLE_SOURCES_UI and AVAILABLE_SOURCES are derived from it so the two lists
# always stay aligned: other code zips them positionally to translate between
# display labels and source ids, so a misordered entry silently swaps sources.
_SOURCE_UI_TO_ID = {
    "Gen AI 360: LLMs": "llm_course",
    "Gen AI 360: Advanced RAG": "advanced_rag_course",
    "Gen AI 360: LangChain": "langchain_course",
    "Towards AI Blog": "towards_ai",
    "Activeloop Docs": "activeloop",
    "HF Transformers Docs": "hf_transformers",
    "Wikipedia": "wikipedia",
    "OpenAI Docs": "openai",
    "LangChain Docs": "langchain_docs",
}

# Labels shown in the "Select Sources" dropdown, in display order.
AVAILABLE_SOURCES_UI = list(_SOURCE_UI_TO_ID)
# Internal source ids, index-aligned with AVAILABLE_SOURCES_UI.
AVAILABLE_SOURCES = list(_SOURCE_UI_TO_ID.values())
# Build the Buster instance that get_answer queries, from the project config.
buster = setup_buster(cfg.buster_cfg)

# Logger for this module; the root handler is configured at INFO below.
logger = logging.getLogger(__name__)

# httpx logging is noisy and uninformative here: only let warnings and errors through.
logging.getLogger("httpx").setLevel(logging.WARNING)

logging.basicConfig(level=logging.INFO)
def save_completion(completion: Completion, question: Optional[str] = None):
    """Persist a finished completion to MongoDB, best effort.

    This is the last step of the answer chain, where Gradio passes only the
    ``completion`` state as input. ``question`` therefore defaults to ``None``;
    it is not used.

    Any failure, including serialization, is logged with its traceback and
    swallowed, so a database problem never breaks the chat.

    Args:
        completion: The completion produced for the latest question.
        question: Unused; kept so existing two-argument callers still work.
    """
    collection = "completion_data-hf"
    try:
        completion_json = completion.to_json(
            columns_to_ignore=["embedding", "similarity", "similarity_to_answer"]
        )
        cfg.mongo_db[collection].insert_one(completion_json)
    except Exception:
        # Deliberately best-effort, but keep the traceback for debugging.
        logger.exception("Something went wrong logging completion to db")
    else:
        logger.info("Completion saved to db")
def log_likes(completion: Completion, like_data: gr.LikeData):
    """Store a like/dislike on a chatbot message in MongoDB, best effort.

    The serialized completion is saved with an extra ``liked`` field taken
    from the like event. Keep the ``gr.LikeData`` annotation: Gradio relies on
    it to inject the event payload, because only ``completion`` is listed as
    an input when this handler is registered.

    Args:
        completion: The completion the feedback refers to.
        like_data: Gradio like event; ``like_data.liked`` is the user's vote.
    """
    collection = "liked_data-test"
    logger.info(f"User reported {like_data.liked=}")
    try:
        completion_json = completion.to_json(
            columns_to_ignore=["embedding", "similarity", "similarity_to_answer"]
        )
        completion_json["liked"] = like_data.liked
        cfg.mongo_db[collection].insert_one(completion_json)
    except Exception:
        # Deliberately best-effort, but keep the traceback for debugging.
        logger.exception("Something went wrong logging like feedback to db")
    else:
        logger.info("Like feedback saved to db")
def log_emails(email: str):
    """Store an email submitted for tutor updates in MongoDB, best effort.

    Leading and trailing whitespace is stripped. Empty submissions are
    ignored and nothing is written.

    Args:
        email: Current value of the email textbox (may be ``None`` or blank).

    Returns:
        Always ``""``. Gradio writes it back to the email textbox to clear it.
    """
    collection = "email_data-test"
    email = (email or "").strip()
    if not email:
        # Nothing to store; just clear the textbox.
        return ""

    logger.info(f"User reported {email=}")
    try:
        cfg.mongo_db[collection].insert_one({"email": email})
    except Exception:
        # Deliberately best-effort, but keep the traceback for debugging.
        logger.exception("Something went wrong logging email to db")
    else:
        logger.info("Email saved to db")
    return ""
def format_sources(matched_documents: pd.DataFrame) -> str:
    """Render matched documents as a markdown list of sources for the chat.

    The input DataFrame is left unmodified. The output is built as follows:

    - ``similarity_to_answer`` is multiplied by 100 and shown as a percentage
      (presumably the raw score lies in [0, 1]; TODO confirm).
    - Rows are ranked by that score, highest first.
    - Rows sharing a ``title`` are collapsed to their best-scoring one.
    - Internal source ids are replaced by their dropdown display labels.

    Args:
        matched_documents: Frame with at least ``source``, ``title``, ``url``
            and ``similarity_to_answer`` columns.

    Returns:
        The formatted markdown message, or ``""`` when there are no documents.
    """
    if matched_documents.empty:
        return ""

    documents_answer_template: str = "📝 Here are the sources I used to answer your question:\n\n{documents}\n\n{footnote}"
    document_template: str = "[🔗 {document.source}: {document.title}]({document.url}), relevance: {document.similarity_to_answer:2.1f} %"
    footnote: str = "I'm a bot 🤖 and not always perfect."

    source_id_to_ui = dict(zip(AVAILABLE_SOURCES, AVAILABLE_SOURCES_UI))

    # assign() returns a new frame, so the caller's DataFrame (e.g. the one held
    # by the Completion) is not mutated. Repeated calls therefore do not compound
    # the x100 scaling, and no column is written on a filtered copy afterwards.
    ranked = (
        matched_documents.assign(
            similarity_to_answer=matched_documents["similarity_to_answer"] * 100,
            source=matched_documents["source"].replace(source_id_to_ui),
        )
        .sort_values("similarity_to_answer", ascending=False)
        .drop_duplicates("title", keep="first")
    )

    documents = "\n".join(
        document_template.format(document=document)
        for _, document in ranked.iterrows()
    )
    return documents_answer_template.format(documents=documents, footnote=footnote)
def add_sources(history, completion):
    """Append a bot message listing the sources behind the latest answer.

    Nothing is appended in any of these cases:

    - there is no completion yet;
    - the completion is flagged as not relevant;
    - there are no sources to list (this avoids an empty chat bubble).

    Args:
        history: Chat history as a list of ``[user_message, bot_message]`` pairs.
        completion: The latest completion, or ``None`` if none was produced.

    Returns:
        ``history``, mutated in place.
    """
    if completion is None or not completion.answer_relevant:
        return history

    formatted_sources = format_sources(completion.matched_documents)
    if formatted_sources:
        history.append([None, formatted_sources])
    return history
def user(user_input, history):
    """Show the user's message in the chat right away and clear the textbox.

    Returns a new history list with ``[user_input, None]`` appended; the bot
    reply slot is filled in later. The input history is not mutated.
    """
    pending_turn = [user_input, None]
    return "", history + [pending_turn]
def get_empty_source_completion(user_input):
    """Build a non-error Completion asking the user to pick at least one source.

    The result has no matched documents (an empty DataFrame).
    """
    no_source_message = "You have to select at least one source from the dropdown menu."
    return Completion(
        user_inputs=user_input,
        answer_text=no_source_message,
        matched_documents=pd.DataFrame(),
        error=False,
    )
def get_answer(history, sources: Optional[list[str]] = None):
    """Stream the bot's answer to the last user message into ``history``.

    Args:
        history: Chat history as a list of ``[user_message, bot_message]``
            pairs. The last pair's user message is the question; its bot
            message is reset to ``""`` and filled in as tokens arrive.
        sources: Display labels selected in the sources dropdown. ``None`` or
            an empty list produces the canned "select at least one source"
            completion instead of querying Buster.

    Yields:
        ``(history, completion)`` after each streamed token. At least one
        update is always yielded, so downstream steps receive *this*
        completion rather than a stale one left from a previous turn.

    Raises:
        KeyError: If a selected label is not a known dropdown label.
    """
    user_input = history[-1][0]
    if not sources:
        completion = get_empty_source_completion(user_input)
    else:
        # Translate dropdown display labels into internal source ids.
        ui_to_source_id = dict(zip(AVAILABLE_SOURCES_UI, AVAILABLE_SOURCES))
        sources_renamed = [ui_to_source_id[label] for label in sources]
        completion = buster.process_input(user_input, sources=sources_renamed)

    history[-1][1] = ""
    streamed_any = False
    for token in completion.answer_generator:
        history[-1][1] += token
        streamed_any = True
        yield history, completion
    if not streamed_any:
        # An empty stream would otherwise leave the completion State unchanged.
        yield history, completion
# Soft theme in blue with Google-hosted fonts, used by the Blocks layout below.
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="blue",
    font=fonts.GoogleFont("Source Sans Pro"),
    font_mono=fonts.GoogleFont("IBM Plex Mono"),
)

with gr.Blocks(theme=theme) as demo:
    with gr.Row():
        gr.Markdown(
            "<h3><center>Towards AI 🤖: A Question-Answering Bot for anything AI-related</center></h3>"
            "<h6><center><i>Powered by Activeloop and 4th Generation Intel® Xeon® Scalable Processors</i></center></h6>"
        )

    source_selection = gr.Dropdown(
        choices=AVAILABLE_SOURCES_UI,
        label="Select Sources",
        value=AVAILABLE_SOURCES_UI,
        multiselect=True,
    )
    chatbot = gr.Chatbot(elem_id="chatbot", show_copy_button=True)

    with gr.Row():
        question = gr.Textbox(
            label="What's your question?",
            placeholder="Ask a question to our AI tutor here...",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary")

    with gr.Row():
        examples = gr.Examples(
            examples=cfg.example_questions,
            inputs=question,
        )

    with gr.Row():
        email = gr.Textbox(
            label="Want to receive updates about our AI tutor?",
            placeholder="Enter your email here...",
            lines=1,
            scale=3,
        )
        submit_email = gr.Button(value="Submit", variant="secondary", scale=0)

    gr.Markdown(
        "This application uses ChatGPT to search the docs for relevant information and answer questions."
        "\n\n### Built on top of the open-source [Buster 🤖](https://www.github.com/jerpint/buster) project. Huge thanks to them."
    )

    # Latest Completion, shared by the answer chain and the like handler.
    completion = gr.State()

    # Clicking "Send" and pressing Enter in the question box run the same chain:
    # 1. show the question;
    # 2. stream the answer;
    # 3. append the sources;
    # 4. persist the completion.
    for trigger in (submit.click, question.submit):
        trigger(user, [question, chatbot], [question, chatbot], queue=False).then(
            get_answer,
            inputs=[chatbot, source_selection],
            outputs=[chatbot, completion],
        ).then(
            add_sources, inputs=[chatbot, completion], outputs=[chatbot]
        ).then(
            save_completion, inputs=[completion]
        )

    chatbot.like(log_likes, completion)

    # The email button and Enter in the email box both store the email and
    # clear the textbox.
    for trigger in (submit_email.click, email.submit):
        trigger(log_emails, email, email)

demo.queue(concurrency_count=CONCURRENCY_COUNT)
demo.launch(debug=True, share=False)