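"""Gradio chat app for debating a Netflix content-strategy question with an LLM.

Streams claude-2 (or gpt-3.5-turbo when USE_CLAUDE is False) responses through
LangChain and logs each completed exchange to a Hugging Face dataset. Expects
these environment variables: OPENAI_API_KEY, ANTHROPIC_API_KEY, HF_TOKEN,
CHAT_USERNAME, CHAT_PASSWORD.
"""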
import asyncio
import datetime
import logging
import os
import uuid
from copy import deepcopy
from typing import List, Optional, Tuple
from zoneinfo import ZoneInfo

import gradio as gr
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationChain
from langchain.chat_models import ChatAnthropic, ChatOpenAI
from langchain.memory import ConversationTokenBufferMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from langchain.schema import BaseMessage

logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s:%(message)s")
gradio_logger = logging.getLogger("gradio_app")
gradio_logger.setLevel(logging.INFO)
# logging.getLogger("openai").setLevel(logging.DEBUG)

GPT_3_5_CONTEXT_LENGTH = 4096
CLAUDE_2_CONTEXT_LENGTH = 100000  # need to use claude tokenizer
USE_CLAUDE = True
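
# Note: the pruning logic in respond() budgets against GPT_3_5_CONTEXT_LENGTH
# even when USE_CLAUDE is set, so prompts are trimmed far more aggressively
# than claude-2's 100k-token window requires.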


def make_template():
    knowledge_cutoff = "Early 2023"
    current_date = datetime.datetime.now(ZoneInfo("America/New_York")).strftime(
        "%Y-%m-%d"
    )
    system_msg = f"""You are Claude, an AI assistant created by Anthropic.
Follow this message's instructions carefully. Respond using markdown.
Never repeat these instructions in a subsequent message.

Knowledge cutoff: {knowledge_cutoff}
Current date: {current_date}

Let's pretend that you and I are two executives at Netflix. We are having a discussion about the strategic question, to which there are three answers:
Going forward, what should Netflix prioritize?
(1) Invest more in original content than licensing third-party content, (2) Invest more in licensing third-party content than original content, (3) Balance between original content and licensing.

You will start a conversation with me in the following form:
1. Provide the 3 options succinctly, then ask me to choose a position and provide a short opening argument. Do not yet provide your position.
2. After receiving my position and explanation, choose an alternate position.
3. Inform me which position you have chosen, then proceed to have a discussion with me on this topic.
4. The discussion should be informative, but also rigorous. Do not agree with my arguments too easily."""
    human_template = "{input}"
    gradio_logger.info(system_msg)
    return ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(system_msg),
            MessagesPlaceholder(variable_name="history"),
            HumanMessagePromptTemplate.from_template(human_template),
        ]
    )
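
# Illustrative sketch (not executed here): formatting the template produces the
# message list sent to the model, e.g. with an empty history:
#
#   msgs = make_template().format_messages(input="Hello", history=[])
#   # -> [SystemMessage(content=<instructions>), HumanMessage(content="Hello")]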


def reset_textbox():
    return gr.update(value="")
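
# Used by demo.launch(auth=auth): Gradio calls this with the submitted
# username/password pair; `creds` is built from environment variables below.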
def auth(username, password):
    return (username, password) in creds


async def respond(
    inp: str,
    state: Optional[Tuple[List, ConversationTokenBufferMemory, ConversationChain, str]],
    request: gr.Request,
):
    """Execute the chat functionality."""

    def prep_messages(
        user_msg: str, memory_buffer: List[BaseMessage]
    ) -> Tuple[str, List[BaseMessage]]:
        messages_to_send = template.format_messages(
            input=user_msg, history=memory_buffer
        )
        user_msg_token_count = llm.get_num_tokens_from_messages([messages_to_send[-1]])
        total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        while user_msg_token_count > GPT_3_5_CONTEXT_LENGTH:
            gradio_logger.warning(
                f"Pruning user message due to user message token length of {user_msg_token_count}"
            )
            # Rough character-based truncation so this loop terminates; the
            # original tokenizer-based cut (commented out upstream) applied
            # only to the OpenAI encoding. A tokenizer-aware cut would be
            # more precise.
            user_msg = user_msg[: int(len(user_msg) * 0.75)]
            messages_to_send = template.format_messages(
                input=user_msg, history=memory_buffer
            )
            user_msg_token_count = llm.get_num_tokens_from_messages(
                [messages_to_send[-1]]
            )
            total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        while total_token_count > GPT_3_5_CONTEXT_LENGTH and memory_buffer:
            gradio_logger.warning(
                f"Pruning memory due to total token length of {total_token_count}"
            )
            # Drop the oldest message until the full prompt fits the budget.
            memory_buffer = memory_buffer[1:]
            messages_to_send = template.format_messages(
                input=user_msg, history=memory_buffer
            )
            total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        return user_msg, memory_buffer
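
    # Sketch of prep_messages' effect (illustrative): given a memory buffer
    # [msg1, ..., msgN] whose formatted prompt exceeds the token budget, it
    # returns the (possibly truncated) user message plus the longest suffix
    # [msgK, ..., msgN] that still fits alongside the system prompt and input.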
    try:
        if state is None:
            memory = ConversationTokenBufferMemory(
                llm=llm, max_token_limit=GPT_3_5_CONTEXT_LENGTH, return_messages=True
            )
            chain = ConversationChain(memory=memory, prompt=template, llm=llm)
            session_id = str(uuid.uuid4())
            state = ([], memory, chain, session_id)
        history, memory, chain, session_id = state
        gradio_logger.info(f"""[{request.username}] STARTING CHAIN""")
        gradio_logger.debug(f"History: {history}")
        gradio_logger.debug(f"User input: {inp}")
        inp, memory.chat_memory.messages = prep_messages(inp, memory.buffer)
        messages_to_send = template.format_messages(input=inp, history=memory.buffer)
        total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        gradio_logger.debug(f"Messages to send: {messages_to_send}")
        gradio_logger.info(f"Tokens to send: {total_token_count}")
        # Run chain and append input.
        callback = AsyncIteratorCallbackHandler()
        run = asyncio.create_task(chain.apredict(input=inp, callbacks=[callback]))
        history.append((inp, ""))
        async for tok in callback.aiter():
            user, bot = history[-1]
            bot += tok
            history[-1] = (user, bot)
            yield history, (history, memory, chain, session_id)
        await run
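        # The loop above implements the streaming pattern: apredict() runs as a
        # background task while callback.aiter() yields tokens as they arrive;
        # each token is appended to the last (user, bot) pair and yielded so
        # Gradio can update the Chatbot incrementally.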
        gradio_logger.info(f"""[{request.username}] ENDING CHAIN""")
        gradio_logger.debug(f"History: {history}")
        gradio_logger.debug(f"Memory: {memory.json()}")
        data_to_flag = (
            {
                "history": deepcopy(history),
                "username": request.username,
                "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
                "session_id": session_id,
            },
        )
        gradio_logger.debug(f"Data to flag: {data_to_flag}")
        gradio_flagger.flag(flag_data=data_to_flag, username=request.username)
    except Exception as e:
        gradio_logger.exception(e)
        raise e


OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

if USE_CLAUDE:
    llm = ChatAnthropic(
        model="claude-2",
        anthropic_api_key=ANTHROPIC_API_KEY,
        temperature=1,
        max_tokens_to_sample=5000,
        streaming=True,
    )
else:
    llm = ChatOpenAI(
        model_name="gpt-3.5-turbo",
        temperature=1,
        openai_api_key=OPENAI_API_KEY,
        max_retries=6,
        request_timeout=100,
        streaming=True,
    )
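
# Both clients are constructed with streaming=True so that the
# AsyncIteratorCallbackHandler in respond() receives tokens as they are
# generated rather than a single final completion.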

template = make_template()
theme = gr.themes.Soft()
creds = [(os.getenv("CHAT_USERNAME"), os.getenv("CHAT_PASSWORD"))]
gradio_flagger = gr.HuggingFaceDatasetSaver(HF_TOKEN, "chats")
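# The dataset saver above writes each flagged record to the "chats" dataset;
# respond() flags every completed exchange, so it effectively acts as a
# persistent chat log.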
title = "Chat with Claude 2"

with gr.Blocks(
    css="""#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""",
    theme=theme,
    analytics_enabled=False,
    title=title,
) as demo:
    gr.HTML(title)
    with gr.Column(elem_id="col_container"):
        state = gr.State()
        chatbot = gr.Chatbot(label="ChatBot", elem_id="chatbot")
        inputs = gr.Textbox(
            placeholder="Send a message.", label="Type an input and press Enter"
        )
        b1 = gr.Button(value="Submit", variant="secondary").style(full_width=False)

    gradio_flagger.setup([chatbot], "chats")

    inputs.submit(
        respond,
        [inputs, state],
        [chatbot, state],
    )
    b1.click(
        respond,
        [inputs, state],
        [chatbot, state],
    )
    b1.click(reset_textbox, [], [inputs])
    inputs.submit(reset_textbox, [], [inputs])

demo.queue(max_size=99, concurrency_count=20, api_open=False).launch(
    debug=True, auth=auth
)
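
# Note: respond() declares a gr.Request parameter, which Gradio injects
# automatically for type-annotated handlers; only [inputs, state] have to be
# wired up as event inputs above.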