Spaces:
Running
Running
File size: 5,194 Bytes
92feab2 a0e37e2 92feab2 a0e37e2 92feab2 a0e37e2 92feab2 a0e37e2 92feab2 a0e37e2 92feab2 a0e37e2 92feab2 a0e37e2 92feab2 a0e37e2 e867f4c 92feab2 a0e37e2 92feab2 a0e37e2 92feab2 a0e37e2 92feab2 a0e37e2 92feab2 a0e37e2 92feab2 a0e37e2 92feab2 a0e37e2 92feab2 a0e37e2 92feab2 a0e37e2 92feab2 a0e37e2 92feab2 a0e37e2 190ec72 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import os
from typing import Any, Dict, List, Optional, Tuple, TypedDict

import boto3
import gradio as gr
from langchain_aws import ChatBedrock
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import LLM
from langchain_openai.chat_models import ChatOpenAI

from ask_candid.base.config.data import ALL_INDICES
from ask_candid.base.config.models import Name2Endpoint
from ask_candid.base.config.rest import OPENAI
from ask_candid.chat import run_chat
from ask_candid.utils import format_chat_ag_response
# Absolute path of the directory holding this module; files under
# ``static/`` (avatar image, stylesheet) are resolved relative to it.
ROOT = os.path.split(os.path.abspath(__file__))[0]

# NOTE(review): BUCKET and PREFIX are not referenced in this file; presumably
# an S3 bucket name and key prefix for reporting data used elsewhere — confirm.
BUCKET = "candid-data-science-reporting"
PREFIX = "Assistant"
class _LoggedComponentsBase(TypedDict):
    """Required part of ``LoggedComponents``."""

    # Components describing the context of a chat exchange; ``build_rag_chat``
    # fills this with the hidden thread-id textbox and the chatbot.
    context: List[gr.components.Component]


class LoggedComponents(_LoggedComponentsBase, total=False):
    """Gradio components whose values are collected for logging.

    Only ``context`` is required. The remaining keys (presumably user-feedback
    inputs — confirm against the logging code) are optional, so a UI that has
    no feedback form, such as ``build_rag_chat`` which sets only ``context``,
    still produces a well-typed ``LoggedComponents``.
    """

    found_helpful: gr.components.Component
    will_recommend: gr.components.Component
    comments: gr.components.Component
    email: gr.components.Component
def select_foundation_model(model_name: str, max_new_tokens: int) -> BaseChatModel:
    """Instantiate the chat model that backs a UI model name.

    Args:
        model_name: Display name chosen in the UI. It is mapped to the
            provider's model identifier through ``Name2Endpoint``.
        max_new_tokens: Cap on generated tokens, forwarded as ``max_tokens``.

    Returns:
        A streaming ``ChatOpenAI`` for ``"gpt-4o"``, or a ``ChatBedrock`` for
        the Bedrock-hosted models listed below. Both run at temperature 0.
        (Both are chat models, hence ``BaseChatModel`` rather than ``LLM``.)

    Raises:
        gr.Error: If ``model_name`` is not one of the supported models.
    """
    # NOTE(review): the UI offers every key of ``Name2Endpoint``; any key not
    # handled here surfaces to the user as the "not supported" error below.
    bedrock_models = {"claude-3.5-haiku", "llama-3.1-70b-instruct", "mistral-large", "mixtral-8x7B"}

    if model_name == "gpt-4o":
        return ChatOpenAI(
            model_name=Name2Endpoint[model_name],
            max_tokens=max_new_tokens,
            api_key=OPENAI["key"],
            temperature=0.0,
            streaming=True,
        )
    if model_name in bedrock_models:
        return ChatBedrock(
            # A fresh Bedrock runtime client is created on every call.
            client=boto3.client("bedrock-runtime"),
            model=Name2Endpoint[model_name],
            max_tokens=max_new_tokens,
            temperature=0.0,
        )
    raise gr.Error(f"Base model `{model_name}` is not supported")
def execute(
    thread_id: str,
    user_input: Dict[str, Any],
    history: List[Dict],
    model_name: str,
    max_new_tokens: int,
    indices: Optional[List[str]] = None,
):
    """Gradio submit handler: run one chat turn with the selected model.

    Args:
        thread_id: Conversation thread identifier (the hidden textbox value).
        user_input: Value of the multimodal textbox for this turn.
        history: Current chatbot messages.
        model_name: UI name of the language model to use.
        max_new_tokens: Cap on generated tokens for the model.
        indices: Source indices to search; ``None`` leaves the choice to
            ``run_chat``.

    Returns:
        Whatever ``run_chat`` returns; in this file it is wired to the
        ``[msg, chatbot, thread_id]`` outputs of the submit event.
    """
    chosen_llm = select_foundation_model(model_name=model_name, max_new_tokens=max_new_tokens)
    return run_chat(
        thread_id=thread_id,
        user_input=user_input,
        history=history,
        llm=chosen_llm,
        indices=indices,
    )
def build_rag_chat() -> Tuple[LoggedComponents, gr.Blocks]:
    """Build the Ask Candid chat interface.

    Lays out, in order: a header with a link to the user guide, an
    "Advanced settings" accordion (source indices, language model, max new
    tokens), the chatbot window, the multimodal message box, a hidden
    thread-id field and a clear button. Submitting a message calls
    ``execute`` and then post-processes the chatbot with
    ``format_chat_ag_response``.

    Returns:
        ``(logged, demo)``: ``logged`` holds the components to log (only
        ``context`` = thread id + chatbot), ``demo`` is the ``gr.Blocks``.
    """
    # Gradio renders components in creation order, so statement order here
    # defines the page layout.
    with gr.Blocks(theme=gr.themes.Soft(), title="Chat") as demo:
        gr.Markdown(
            """
<h1>Ask Candid</h1>
<p>
Please read the <a
href='https://info.candid.org/chatbot-reference-guide'
target="_blank"
rel="noopener noreferrer"
>guide</a> to get started.
</p>
<hr>
"""
        )

        with gr.Accordion(label="Advanced settings", open=False):
            # All known sources are selected by default.
            es_indices = gr.CheckboxGroup(
                choices=list(ALL_INDICES),
                value=list(ALL_INDICES),
                label="Sources to include",
                interactive=True,
            )
            llmname = gr.Radio(
                label="Language model",
                value="gpt-4o",
                choices=list(Name2Endpoint.keys()),
                interactive=True,
            )
            max_new_tokens = gr.Slider(
                value=256 * 3,
                minimum=128,
                maximum=2048,
                step=128,
                label="Max new tokens",
                interactive=True,
            )

        with gr.Column():
            # ``bubble_full_width`` is intentionally not passed: it is
            # deprecated in Gradio 5, has no effect and only emits a warning.
            chatbot = gr.Chatbot(
                label="AskCandid",
                elem_id="chatbot",
                avatar_images=(
                    None,
                    os.path.join(ROOT, "static", "candid_logo_yellow.png"),
                ),
                height="45vh",
                type="messages",
                show_label=False,
                show_copy_button=True,
                show_share_button=True,
                show_copy_all_button=True,
            )
            msg = gr.MultimodalTextbox(label="Your message", interactive=True)
            # Hidden field carrying the conversation thread id between turns.
            thread_id = gr.Text(visible=False, value="", label="thread_id")
            gr.ClearButton(components=[msg, chatbot, thread_id], size="sm")

        # pylint: disable=no-member
        chat_msg = msg.submit(
            fn=execute,
            inputs=[thread_id, msg, chatbot, llmname, max_new_tokens, es_indices],
            outputs=[msg, chatbot, thread_id],
        )
        chat_msg.then(format_chat_ag_response, chatbot, chatbot, api_name="bot_response")

    logged = LoggedComponents(context=[thread_id, chatbot])
    return logged, demo
def build_app():
    """Assemble the top-level Gradio application.

    Builds the chat UI with ``build_rag_chat`` (its logged-components half is
    not used here), loads ``static/chatStyle.css`` from this module's
    directory, and wraps the chat in a single-tab ``gr.TabbedInterface``
    using the Soft theme.

    Returns:
        The ``gr.TabbedInterface`` to launch.
    """
    chat_blocks = build_rag_chat()[1]

    stylesheet_path = os.path.join(ROOT, "static", "chatStyle.css")
    with open(stylesheet_path, "r", encoding="utf8") as stylesheet:
        stylesheet_text = stylesheet.read()

    return gr.TabbedInterface(
        interface_list=[chat_blocks],
        tab_names=["AskCandid"],
        theme=gr.themes.Soft(),
        css=stylesheet_text,
    )
def load_auth_credentials(
    env: Optional[Dict[str, str]] = None,
) -> List[Tuple[str, str]]:
    """Collect the login credential pairs configured in the environment.

    Reads ``APP_USERNAME``/``APP_PASSWORD`` and then
    ``APP_PUBLIC_USERNAME``/``APP_PUBLIC_PASSWORD``. A pair is kept only when
    both its username and password are non-empty. A missing variable would
    otherwise become a ``None`` entry in Gradio's auth list.

    Args:
        env: Mapping to read the variables from; defaults to ``os.environ``.

    Returns:
        The complete ``(username, password)`` pairs, in the order above.
    """
    source = os.environ if env is None else env
    credentials = []
    for user_var, password_var in (
        ("APP_USERNAME", "APP_PASSWORD"),
        ("APP_PUBLIC_USERNAME", "APP_PUBLIC_PASSWORD"),
    ):
        username = source.get(user_var)
        password = source.get(password_var)
        if username and password:
            credentials.append((username, password))
    return credentials


if __name__ == "__main__":
    auth_credentials = load_auth_credentials()
    if not auth_credentials:
        # Refuse to start rather than launch without any usable login.
        raise RuntimeError(
            "No login credentials configured: set APP_USERNAME/APP_PASSWORD "
            "and/or APP_PUBLIC_USERNAME/APP_PUBLIC_PASSWORD."
        )
    app = build_app()
    # Queue at most 5 pending requests; the API page is hidden.
    app.queue(max_size=5).launch(
        show_api=False,
        auth=auth_credentials,
        auth_message="Login to Candid's AI assistant",
        ssr_mode=False,
    )
|