|
import gradio as gr |
|
import requests |
|
import json |
|
import os |
|
|
|
|
|
# Credential read from the environment. If API_KEY is unset this is None and
# the Authorization header below becomes "Bearer None" (presumably rejected
# by the endpoint — NOTE(review): consider failing fast).
API_KEY = os.environ.get("API_KEY")

# NVCF "pexec" endpoint for the hosted model function.
INVOKE_URL = (
    "https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/"
    "df2bee43-fb69-42b9-9ee5-f4eabbeaf3a8"
)

# Headers sent with every request made by call_api.
headers = {
    "Authorization": "Bearer " + str(API_KEY),  # bearer-token auth
    "accept": "text/event-stream",  # ask for a server-sent-event stream
    "content-type": "application/json",  # the payload is JSON
}
|
|
|
# Default system prompt: pre-fills the "System Message" box and is the
# fallback used by chat() when that box is blank.
BASE_SYSTEM_MESSAGE = (
    "I carefully provide accurate, factual, thoughtful, nuanced answers "
    "and am brilliant at reasoning."
)
|
|
|
def clear_chat(chat_history_state, chat_message):
    """Return a fresh, empty conversation and an empty message string.

    Both arguments are ignored; they are accepted so the function can be
    bound to event inputs.

    Returns:
        ``([], "")`` — a new empty history list and an empty message.
    """
    return [], ""
|
|
|
def user(message, history):
    """Record ``message`` as a user turn at the end of ``history``.

    A falsy ``history`` (``None`` or an empty list) is replaced by a new
    list; otherwise the list passed in is extended in place.

    Returns:
        ``("", history)`` — an empty string (presumably to clear the input
        box) and the updated history.
    """
    conversation = history if history else []
    conversation.append({"role": "user", "content": message})
    return "", conversation
|
|
|
def call_api(history, max_tokens, temperature, top_p, seed=42, timeout=120):
    """Stream a chat completion from ``INVOKE_URL`` and return the full reply text.

    Args:
        history: List of ``{"role": ..., "content": ...}`` dicts, sent verbatim
            as the ``messages`` field of the payload.
        max_tokens, temperature, top_p, seed: Sampling parameters forwarded in
            the JSON payload.
        timeout: Passed to ``requests.post``; bounds the connect time and each
            wait for stream data so a stalled endpoint cannot hang the caller.

    Returns:
        The concatenation of every ``choices[0].delta.content`` in the stream
        (possibly ``""``).

    Raises:
        requests.HTTPError: if the endpoint answers with an error status.
        json.JSONDecodeError: if a ``data:`` event holds malformed JSON.
    """
    payload = {
        "messages": history,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "seed": seed,
        "stream": True,
    }

    pieces = []
    # The context manager releases the streamed connection even if parsing raises.
    with requests.post(
        INVOKE_URL, headers=headers, json=payload, stream=True, timeout=timeout
    ) as response:
        # Surface HTTP errors instead of silently returning an empty reply.
        response.raise_for_status()
        for raw_line in response.iter_lines():
            if not raw_line:
                continue  # blank lines separate SSE events
            line = raw_line.decode("utf-8")
            if not line.startswith("data:"):
                continue  # other SSE fields / comment lines carry no content
            data = line[len("data:"):].strip()
            if data == "[DONE]":
                break  # end-of-stream sentinel used by OpenAI-style APIs; not JSON
            chunk = json.loads(data)
            choices = chunk.get("choices") or []
            if choices:
                # "delta" or "content" may be absent or null on some chunks.
                content = (choices[0].get("delta") or {}).get("content")
                if content:
                    pieces.append(content)
    return "".join(pieces)
|
|
|
def chat(history, system_message, max_tokens, temperature, top_p, top_k, repetition_penalty):
    """Request an assistant reply for ``history`` and append it in place.

    A request is made only when the last turn in ``history`` is a user turn.
    The system prompt goes first in the request payload but is not stored in
    ``history``. It is therefore not repeated on every turn and not returned
    for display.

    Args:
        history: List of ``{"role": ..., "content": ...}`` dicts (may be empty
            or ``None``). Mutated in place when a reply is received.
        system_message: System prompt. ``None`` or whitespace-only falls back
            to ``BASE_SYSTEM_MESSAGE``.
        max_tokens, temperature, top_p: Forwarded to ``call_api``.
        top_k, repetition_penalty: Accepted for interface compatibility. They
            are not forwarded because ``call_api`` has no such parameters.

    Returns:
        ``(history, "", "")`` — the (possibly extended) history plus two empty
        strings for the caller's text outputs.
    """
    if not history or history[-1].get("role") != "user":
        return history, "", ""

    if system_message and system_message.strip():
        system_prompt = system_message
    else:
        system_prompt = BASE_SYSTEM_MESSAGE

    # Skip any system turns already present in history so the prompt appears
    # exactly once, at the start of the payload.
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(turn for turn in history if turn.get("role") != "system")

    assistant_response = call_api(messages, max_tokens, temperature, top_p)
    if assistant_response:
        history.append({"role": "assistant", "content": assistant_response})
    return history, "", ""
|
|
|
|
|
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Your Chatbot Interface")
            # History is a list of {"role", "content"} dicts, so the Chatbot
            # uses the "messages" format (NOTE(review): requires Gradio >= 4.44).
            chatbot = gr.Chatbot(type="messages")
            message = gr.Textbox(label="What do you want to chat about?", placeholder="Ask me anything.", lines=3)
            submit = gr.Button(value="Send message")
            clear = gr.Button(value="New topic")
            system_msg = gr.Textbox(BASE_SYSTEM_MESSAGE, label="System Message", placeholder="System prompt.", lines=5)
            max_tokens = gr.Slider(20, 512, label="Max Tokens", step=20, value=500)
            temperature = gr.Slider(0.0, 1.0, label="Temperature", step=0.1, value=0.7)
            top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.95)
            chat_history_state = gr.State([])

    def update_chatbot(text, chat_history, system_prompt, max_new_tokens, temp, nucleus_p):
        """Handle "Send message": add the user turn, fetch the reply, clear the box.

        The system prompt and sampling settings arrive as event inputs.
        Reading ``component.value`` here would only give each component's
        initial value and ignore what the user set.

        Returns:
            ``(history for the Chatbot, history for the State, new textbox text)``.
        """
        if not text or not text.strip():
            # Nothing to send: leave the conversation and the textbox as they are.
            return chat_history, chat_history, text
        _, chat_history = user(text, chat_history)  # user() returns exactly 2 values
        chat_history, _, _ = chat(chat_history, system_prompt, max_new_tokens, temp, nucleus_p, 40, 1.1)
        return chat_history, chat_history, ""

    def reset_conversation(history, text):
        """Handle "New topic": empty the chat display, the stored history and the input box."""
        cleared_history, cleared_text = clear_chat(history, text)
        return [], cleared_history, cleared_text

    submit.click(
        fn=update_chatbot,
        inputs=[message, chat_history_state, system_msg, max_tokens, temperature, top_p],
        outputs=[chatbot, chat_history_state, message],
    )

    # The handler accepts the two inputs Gradio passes it and also resets the
    # Chatbot component, not just the stored state.
    clear.click(
        fn=reset_conversation,
        inputs=[chat_history_state, message],
        outputs=[chatbot, chat_history_state, message],
    )
|
|
|
# Start the server only when this file is run as a script, so importing the
# module (e.g. to reuse `demo`) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()