import gradio as gr
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer,
)
import threading
import time
# -----------------------------------------------------------------------------
# 1. MODEL LOADING
# -----------------------------------------------------------------------------
# In this advanced example, we'll instantiate the model directly (instead of using pipeline).
# We'll do streaming outputs via TextIteratorStreamer.
MODEL_NAME = "microsoft/phi-4" # Replace with an actual HF model if phi-4 is unavailable
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
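# DEVICE is used for the smaller fallback model and for input tensors;
# the phi-4 path relies on device_map="auto" to place the weights.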
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map="auto")
except Exception:
# Fallback if model is not found or large. Here we default to a smaller model
MODEL_NAME = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()
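# eval() switches the model to inference mode (e.g. disables dropout); no gradients are needed here.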
# -----------------------------------------------------------------------------
# 2. CONVERSATION / PROMPTS
# -----------------------------------------------------------------------------
# We'll keep track of conversation using a list of dictionaries:
# [
# {"role": "system", "content": "..."},
# {"role": "developer", "content": "..."},
# {"role": "user", "content": "User message"},
# {"role": "assistant", "content": "Assistant answer"},
# ...
# ]
#
# We’ll also build in a mock retrieval system that merges knowledge snippets
# into the final prompt if the user chooses to do so.
DEFAULT_SYSTEM_PROMPT = (
"You are Philos, an advanced AI system created by ACC (Algorithmic Computer-generated Consciousness). "
"Answer user queries accurately, thoroughly, and helpfully. Keep your responses relevant and correct."
)
DEFAULT_DEVELOPER_PROMPT = (
"Ensure that you respond in a style that is professional, clear, and approachable. "
"Include reasoning steps if needed, but keep them concise."
)
# A small dictionary to emulate knowledge retrieval
# In real scenarios, you might use an actual vector DB + retrieval method
MOCK_KB = {
"python": "Python is a high-level, interpreted programming language famous for its readability and flexibility.",
"accelerate library": "The accelerate library by HF helps in distributed training and inference.",
"phi-4 architecture": "phi-4 is a 14B-parameter, decoder-only Transformer with a 16K context window.",
}
def retrieve_knowledge(user_query):
# Simple naive approach: check keywords in user query
# Return a knowledge snippet if found
matches = []
for keyword, snippet in MOCK_KB.items():
if keyword.lower() in user_query.lower():
matches.append(snippet)
return matches
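# Example: retrieve_knowledge("What is Python?") matches the "python" key above and
# returns ["Python is a high-level, interpreted programming language ..."].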
# -----------------------------------------------------------------------------
# 3. HELPER: Build the prompt from conversation
# -----------------------------------------------------------------------------
def build_prompt(conversation):
"""
Convert conversation (list of role/content dicts) into a single text prompt
that the model can process. We adopt a simple format:
system, developer, user, assistant, ...
"""
prompt = ""
for msg in conversation:
if msg["role"] == "system":
prompt += f"[System]\n{msg['content']}\n"
elif msg["role"] == "developer":
prompt += f"[Developer]\n{msg['content']}\n"
elif msg["role"] == "user":
prompt += f"[User]\n{msg['content']}\n"
else: # assistant
prompt += f"[Assistant]\n{msg['content']}\n"
prompt += "[Assistant]\n" # We end with an assistant role so model can continue
return prompt
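# For a single user turn the built prompt looks like:
#   [System]
#   <system prompt>
#   [Developer]
#   <developer prompt>
#   [User]
#   <user message>
#   [Assistant]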
# -----------------------------------------------------------------------------
# 4. STREAMING GENERATION
# -----------------------------------------------------------------------------
def generate_tokens_stream(prompt, temperature=0.7, top_p=0.9, max_new_tokens=128):
"""
Uses TextIteratorStreamer to yield tokens one by one (or in small chunks).
"""
    # skip_prompt=True so the streamer yields only newly generated text, not the echoed prompt
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)
generation_kwargs = dict(
input_ids=input_ids,
streamer=streamer,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
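    # pad_token_id is set to EOS because GPT-2-style models define no dedicated pad token,
    # which would otherwise trigger a warning during generation.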
# We'll run generation in a background thread, streaming tokens
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# Stream tokens
partial_text = ""
for new_token in streamer:
partial_text += new_token
yield partial_text
thread.join()
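    # join() returns only after the streamer is exhausted, i.e. once generate() has
    # produced its final token in the background thread.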
# -----------------------------------------------------------------------------
# 5. MAIN CHAT FUNCTION
# -----------------------------------------------------------------------------
def advanced_chat(user_msg, conversation, system_prompt, dev_prompt, retrieve_flg, temperature, top_p):
"""
- Update conversation with the user's message
- Optionally retrieve knowledge and incorporate into the system or developer prompt
- Build the final prompt
- Stream the assistant's reply
"""
# If user message is empty
if not user_msg.strip():
yield "Please enter a message."
return
# 1) Construct or update system/dev prompts
system_message = {"role": "system", "content": system_prompt}
developer_message = {"role": "developer", "content": dev_prompt}
# 2) Insert or replace system/dev in the conversation
# We'll assume the first system/dev messages are at the start of conversation
# or add them if not present
    filtered = [msg for msg in conversation if msg["role"] not in ["system", "developer"]]
    # Mutate the list in place so the gr.State list passed in retains the updated history
    conversation[:] = [system_message, developer_message] + filtered
# 3) Append user's message
conversation.append({"role": "user", "content": user_msg})
# 4) Retrieve knowledge if user toggled "Include knowledge retrieval"
if retrieve_flg:
knowledge_snippets = retrieve_knowledge(user_msg)
if knowledge_snippets:
# We can just append them to developer or system content for simplicity
knowledge_text = "\n".join(["[Knowledge] " + s for s in knowledge_snippets])
conversation[1]["content"] += f"\n\n[Additional Knowledge]\n{knowledge_text}"
# 5) Build final prompt
prompt = build_prompt(conversation)
# 6) Stream the assistant’s response
partial_response = ""
for partial_text in generate_tokens_stream(prompt, temperature, top_p):
partial_response = partial_text
yield partial_text # Send partial tokens to Gradio for display
# 7) Now that generation is complete, append final assistant message
conversation.append({"role": "assistant", "content": partial_response})
# -----------------------------------------------------------------------------
# 6. BUILD GRADIO INTERFACE
# -----------------------------------------------------------------------------
def build_ui():
with gr.Blocks(title="PhilosBeta-Advanced", css="#chatbot{height:550px} .overflow-y-auto{max-height:550px}") as demo:
gr.Markdown("# **PhilosBeta: Advanced Demo**")
gr.Markdown(
"An example of multi-turn conversation with streaming responses, "
"optional retrieval, and custom system/developer prompts."
)
# State to store the conversation as a list of role/content dicts
conversation_state = gr.State([])
# TEXT ELEMENTS
with gr.Row():
with gr.Column():
system_prompt_box = gr.Textbox(
label="System Prompt",
value=DEFAULT_SYSTEM_PROMPT,
lines=3
)
developer_prompt_box = gr.Textbox(
label="Developer Prompt",
value=DEFAULT_DEVELOPER_PROMPT,
lines=3
)
with gr.Column():
retrieve_flag = gr.Checkbox(label="Include Knowledge Retrieval", value=False)
                temperature_slider = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature")
top_p_slider = gr.Slider(0.0, 1.0, 0.9, step=0.05, label="Top-p")
max_tokens_info = gr.Markdown("Max new tokens = 128 (fixed in code).")
# MAIN CHAT UI
        chatbox = gr.Chatbot(label="Philos Conversation", elem_id="chatbot", height=500)
user_input = gr.Textbox(
label="Your Message",
placeholder="Type here...",
lines=3
)
send_btn = gr.Button("Send", variant="primary")
# ---------------------------------------------------------------------
# ACTION: Handle user input
# ---------------------------------------------------------------------
        def user_send(
            user_text, conversation, sys_prompt, dev_prompt, retrieve_flg, temperature, top_p
        ):
            """
            Calls advanced_chat() and streams partial replies into the Chatbot.
            Yields (chat_pairs, conversation) so Gradio updates both outputs in real time.
            """
            def to_pairs(conv, pending=None):
                # Convert role/content dicts into (user, assistant) pairs for gr.Chatbot.
                pairs = []
                for msg in conv:
                    if msg["role"] == "user":
                        pairs.append([msg["content"], None])
                    elif msg["role"] == "assistant" and pairs:
                        pairs[-1][1] = msg["content"]
                if pending is not None and pairs:
                    pairs[-1][1] = pending
                return pairs

            for partial_text in advanced_chat(
                user_msg=user_text,
                conversation=conversation,
                system_prompt=sys_prompt,
                dev_prompt=dev_prompt,
                retrieve_flg=retrieve_flg,
                temperature=temperature,
                top_p=top_p,
            ):
                yield to_pairs(conversation, pending=partial_text), conversation
            # Final yield once the assistant message has been appended to the conversation
            yield to_pairs(conversation), conversation

        # Gradio handles generator outputs for streaming: each yield updates the Chatbot.
send_btn.click(
fn=user_send,
inputs=[
user_input,
conversation_state,
system_prompt_box,
developer_prompt_box,
retrieve_flag,
temperature_slider,
top_p_slider,
],
            outputs=[chatbox, conversation_state],
)
# We also let the user press Enter to send messages
user_input.submit(
fn=user_send,
inputs=[
user_input,
conversation_state,
system_prompt_box,
developer_prompt_box,
retrieve_flag,
temperature_slider,
top_p_slider,
],
            outputs=[chatbox, conversation_state],
)
return demo
# -----------------------------------------------------------------------------
# 7. LAUNCH
# -----------------------------------------------------------------------------
if __name__ == "__main__":
ui = build_ui()
ui.launch()