# Yyyy / app.py — last commit ce67cd9 by Athspi: "Update app.py" (verified), 14.8 kB.
import gradio as gr
import onnxruntime_genai as og
import time
import os
from huggingface_hub import snapshot_download
import argparse
import logging
# --- Logging Setup ---
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# --- Configuration ---
# Hugging Face repository hosting the ONNX exports of Phi-4-mini-instruct.
MODEL_REPO = "microsoft/Phi-4-mini-instruct-onnx"

# Default target for Hugging Face Spaces: CPU with the INT4 variant.
# The "cpu" provider corresponds to installing the plain 'onnxruntime-genai' package.
EXECUTION_PROVIDER = "cpu"
MODEL_VARIANT_GLOB = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/*"

# (Optional) GPU alternative, which requires 'onnxruntime-genai-cuda' instead:
# EXECUTION_PROVIDER = "cuda"
# MODEL_VARIANT_GLOB = "gpu/gpu-int4-rtn-block-32/*"

# Directory (inside the Space) that the model snapshot is downloaded into.
LOCAL_MODEL_DIR = "./phi4-mini-onnx-model"

# Links and images used by the UI.
HF_LOGO_URL = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"
HF_MODEL_URL = "https://huggingface.co/" + MODEL_REPO
ORT_GENAI_URL = "https://github.com/microsoft/onnxruntime-genai"
PHI_LOGO_URL = "https://microsoft.github.io/phi/assets/img/logo-final.png"  # bot avatar

# --- Global state shared by the loader, the generator and the UI ---
model = tokenizer = None  # filled in by initialize_model()
# Display name of the variant: the folder that holds the glob's trailing "*",
# e.g. "cpu-int4-rtn-block-32-acc-level-4".
model_variant_name = os.path.split(os.path.dirname(MODEL_VARIANT_GLOB))[1]
model_status = "Initializing..."
# --- Model Download and Load ---
def _model_files_present(model_dir):
    """Return True if *model_dir* looks like a complete onnxruntime-genai model folder.

    A plain "directory is non-empty" test treats an interrupted download as
    complete. onnxruntime-genai model folders ship a ``genai_config.json``,
    so its presence is used as the completeness marker instead.
    """
    return os.path.isfile(os.path.join(model_dir, "genai_config.json"))


def initialize_model():
    """Download (if needed) and load the ONNX model and tokenizer.

    Sets the module globals ``model``, ``tokenizer`` and ``model_status``.
    ``model`` and ``tokenizer`` are only assigned once BOTH have loaded, so a
    failure never leaves a half-initialized pair behind.

    Raises:
        RuntimeError: if the download or the load fails. ``model_status`` is
            set to a human-readable error first, and the original exception
            is chained as ``__cause__``.
    """
    global model, tokenizer, model_status
    logging.info("--- Initializing ONNX Runtime GenAI ---")
    model_status = "Downloading model..."
    logging.info(model_status)

    # --- Download ---
    model_path = os.path.join(LOCAL_MODEL_DIR, os.path.dirname(MODEL_VARIANT_GLOB))
    if _model_files_present(model_path):
        logging.info(f"Model variant found in {model_path}. Skipping download.")
    else:
        logging.info(f"Downloading model variant '{MODEL_VARIANT_GLOB}' from {MODEL_REPO}...")
        try:
            # `local_dir_use_symlinks` is deprecated in huggingface_hub (ignored
            # since v0.23), so it is not passed. Files land directly in local_dir.
            snapshot_download(
                MODEL_REPO,
                allow_patterns=[MODEL_VARIANT_GLOB],
                local_dir=LOCAL_MODEL_DIR,
            )
        except Exception as e:
            logging.error(f"Error downloading model: {e}", exc_info=True)
            model_status = f"Error downloading model: {e}"
            raise RuntimeError(f"Failed to download model: {e}") from e
        logging.info(f"Model downloaded to: {model_path}")

    # --- Load ---
    model_status = f"Loading model ({EXECUTION_PROVIDER.upper()})..."
    logging.info(model_status)
    # No explicit device type is passed: the provider follows the installed
    # onnxruntime-genai package, and EXECUTION_PROVIDER is informational here.
    logging.info(f"Using provider based on installed package (expecting: {EXECUTION_PROVIDER})")
    try:
        loaded_model = og.Model(model_path)
        loaded_tokenizer = og.Tokenizer(loaded_model)
    except AttributeError as ae:
        logging.error(f"AttributeError during model/tokenizer init: {ae}", exc_info=True)
        logging.error("This might indicate an installation issue or version incompatibility with onnxruntime_genai.")
        model_status = f"Init Error: {ae}"
        raise RuntimeError(f"Failed to initialize model/tokenizer: {ae}") from ae
    except Exception as e:
        logging.error(f"Error loading model or tokenizer: {e}", exc_info=True)
        model_status = f"Error loading model: {e}"
        raise RuntimeError(f"Failed to load model: {e}") from e

    model, tokenizer = loaded_model, loaded_tokenizer
    model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"
    logging.info("Model and Tokenizer loaded successfully.")
# --- Generation Function (Core Logic) ---
def _build_phi4_prompt(prompt, history):
    """Render earlier turns plus the new user message in the Phi-4 instruct format.

    Args:
        prompt: The current user message.
        history: Earlier turns as ``[user_msg, assistant_msg]`` pairs, excluding
            *prompt*. ``assistant_msg`` may be empty/None for an unanswered turn.
            ``None`` is treated as no history.

    Returns:
        The full prompt, ending with ``<|assistant|>\\n`` so the model continues
        as the assistant.
    """
    parts = []
    for user_msg, assistant_msg in history or ():
        parts.append(f"<|user|>\n{user_msg}<|end|>\n")
        if assistant_msg:  # skip turns that never received a reply
            parts.append(f"<|assistant|>\n{assistant_msg}<|end|>\n")
    parts.append(f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n")
    return "".join(parts)


def generate_response_stream(prompt, history, max_length, temperature, top_p, top_k):
    """Generate a reply to *prompt* with the Phi-4 ONNX model, yielding text chunks.

    Args:
        prompt: The current user message.
        history: Earlier ``[user, assistant]`` pairs (not including *prompt*).
        max_length: Maximum number of tokens in the *response*. The prompt
            length is added before this is handed to onnxruntime-genai, whose
            ``max_length`` search option counts prompt + generated tokens.
        temperature: Sampling temperature; ``0`` (or less) selects greedy decoding.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of most likely tokens considered when sampling.

    Yields:
        Decoded text fragments as they are produced. Errors are not raised;
        a single error message is yielded instead and ``model_status`` is updated.
    """
    global model_status
    if model is None or tokenizer is None:
        model_status = "Error: Model not initialized!"
        yield "Error: Model not initialized. Please check logs."
        return

    full_prompt = _build_phi4_prompt(prompt, history)
    logging.info(f"Generating response (MaxL: {max_length}, Temp: {temperature}, TopP: {top_p}, TopK: {top_k})")

    try:
        input_tokens = tokenizer.encode(full_prompt)
        search_options = {
            # The UI slider means "tokens in the response". Offsetting by the
            # prompt length keeps long conversations from leaving no room for a reply.
            "max_length": len(input_tokens) + int(max_length),
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            # Temperature 0 means deterministic output, so use greedy search
            # instead of sampling with a zero temperature.
            "do_sample": bool(temperature > 0),
        }
        params = og.GeneratorParams(model)
        params.set_search_options(**search_options)

        # Newer onnxruntime-genai releases feed the prompt via
        # Generator.append_tokens() and dropped GeneratorParams.input_ids.
        # Keep the legacy path for releases that predate append_tokens().
        use_append_tokens = hasattr(og.Generator, "append_tokens")
        if not use_append_tokens:
            params.input_ids = input_tokens

        start_time = time.time()
        generator = og.Generator(model, params)
        if use_append_tokens:
            generator.append_tokens(input_tokens)
        # Older releases need an explicit compute_logits() per step. Newer ones
        # do this inside generate_next_token() and no longer have the method.
        compute_logits = getattr(generator, "compute_logits", None)
        # Incremental decoding keeps characters that span several tokens intact,
        # instead of decoding each token in isolation.
        token_stream = tokenizer.create_stream()

        model_status = "Generating..."
        logging.info("Streaming response...")
        first_token_time = None
        token_count = 0
        while not generator.is_done():
            if compute_logits is not None:
                compute_logits()
            generator.generate_next_token()
            if first_token_time is None:
                first_token_time = time.time()  # time to first token
            next_token = generator.get_next_tokens()[0]
            token_count += 1
            decoded_chunk = token_stream.decode(next_token)
            # Secondary stop condition in case <|end|> is emitted as text.
            if decoded_chunk == "<|end|>":
                logging.info("Assistant explicitly generated <|end|> token string.")
                break
            if decoded_chunk:  # the stream may return "" while it buffers partial text
                yield decoded_chunk

        end_time = time.time()
        ttft = (first_token_time - start_time) * 1000 if first_token_time else -1
        total_time = end_time - start_time
        tps = (token_count / total_time) if total_time > 0 else 0
        logging.info(f"Generation complete. Tokens: {token_count}, Total Time: {total_time:.2f}s, TTFT: {ttft:.2f}ms, TPS: {tps:.2f}")
        model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"
    except Exception as e:
        logging.error(f"Error during generation: {e}", exc_info=True)
        model_status = f"Error during generation: {e}"
        yield f"\n\nSorry, an error occurred during generation: {e}"
# --- Gradio Interface Functions ---
# 1. Function to add user message to chat history
def add_user_message(user_message, history):
    """Append the user's message to the chat history as a pending turn.

    Args:
        user_message: Text from the input box.
        history: Current Chatbot history as ``[user, bot]`` pairs. ``None`` is
            treated as an empty history.

    Returns:
        ``("", new_history)``: an empty string to clear the input box, and a
        new list with ``[user_message, None]`` appended. The bot slot stays
        ``None`` until a reply is generated. Empty or whitespace-only messages
        return the history unchanged, so nothing is sent to the model.
    """
    if not user_message or not user_message.strip():
        return "", history
    # Build a new list rather than mutating the caller's history.
    return "", (history or []) + [[user_message, None]]
# 2. Function to handle bot response generation and streaming
def generate_bot_response(history, max_length, temperature, top_p, top_k):
    """Stream the bot's reply to the last (pending) user turn into *history*.

    This is a generator. Gradio re-renders the Chatbot on every yielded value.

    Args:
        history: Chatbot history as ``[user, bot]`` pairs. The last pair is
            expected to have ``bot is None``, i.e. it is awaiting a reply.
        max_length, temperature, top_p, top_k: Forwarded unchanged to
            ``generate_response_stream``.

    Yields:
        The full history list, with the last turn's bot message growing as
        chunks arrive. If there is no pending user turn (empty history, or the
        last turn already has a reply), the history is yielded once, unchanged.
    """
    if not history or history[-1][1] is not None:
        # `return value` inside a generator does not emit the value, so yield
        # explicitly to give Gradio a concrete (unchanged) output.
        yield history
        return

    user_prompt = history[-1][0]
    model_history = history[:-1]  # every turn before the pending one
    response_stream = generate_response_stream(
        user_prompt, model_history, max_length, temperature, top_p, top_k
    )

    history[-1][1] = ""
    # Emit once up front so the bot bubble appears even if the stream yields nothing.
    yield history
    for chunk in response_stream:
        history[-1][1] += chunk
        yield history
# 3. Function to clear chat
def clear_chat():
    """Reset the chat UI and refresh the status display.

    Returns:
        ``(None, [], status)`` for the input Textbox, the Chatbot and the
        status Textbox respectively.

    If the model and tokenizer are both loaded, initialization succeeded. Any
    leftover status, such as an error from a previous generation, is then
    replaced with the ready message. If initialization failed, its error
    status is kept so the user can still see why the model is unavailable.
    """
    global model_status
    if model is not None and tokenizer is not None:
        model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"
    return None, [], model_status
# --- Initialize Model on App Start ---
try:
    initialize_model()
except Exception as e:
    # initialize_model() logs the underlying error and sets model_status
    # before raising. The app keeps starting so the UI can display that status
    # instead of the whole Space crashing.
    logging.critical(f"FATAL: Model initialization failed: {e}")
# --- Gradio Interface ---
logging.info("Creating Gradio Interface...")

# Soft theme: blue/sky accents on a slate neutral palette, using the Inter
# web font with system sans-serif fallbacks.
theme = gr.themes.Soft(
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
    neutral_hue="slate",
    secondary_hue="sky",
    primary_hue="blue",
)
with gr.Blocks(theme=theme, title="Phi-4 Mini ONNX Chat") as demo:
    # --- Header: description on the left, logo and model status on the right ---
    with gr.Row(equal_height=False):
        with gr.Column(scale=3):
            gr.Markdown(f"""
# Phi-4 Mini Instruct ONNX Chat 🤖
Interact with the quantized `{model_variant_name}` version of [`{MODEL_REPO}`]({HF_MODEL_URL})
running efficiently via [`onnxruntime-genai`]({ORT_GENAI_URL}) ({EXECUTION_PROVIDER.upper()}).
""")
        with gr.Column(scale=1, min_width=150):
            gr.Image(HF_LOGO_URL, elem_id="hf-logo", show_label=False, show_download_button=False, container=False, height=50)
            # Seeded from the global model_status when the UI is built. It is
            # refreshed after every reply (see _wire_chat_chain) and by Clear.
            model_status_text = gr.Textbox(value=model_status, label="Model Status", interactive=False, max_lines=2)

    # --- Main layout: chat on the left, generation settings on the right ---
    with gr.Row():
        # Chat column
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Conversation",
                height=600,
                layout="bubble",
                bubble_full_width=False,
                avatar_images=(None, PHI_LOGO_URL),  # (user, bot)
            )
            with gr.Row():
                prompt_input = gr.Textbox(
                    label="Your Message",
                    placeholder="<|user|>\nType your message here...\n<|end|>",
                    lines=4,
                    scale=9,  # wider than the button column next to it
                )
                with gr.Column(scale=1, min_width=120):
                    submit_button = gr.Button("Send", variant="primary", size="lg")
                    clear_button = gr.Button("🗑️ Clear Chat", variant="secondary")
        # Settings column
        with gr.Column(scale=1, min_width=250):
            gr.Markdown("### ⚙️ Generation Settings")
            with gr.Group():  # group the sliders visually
                max_length = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max Length", info="Max tokens in response.")
                temperature = gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature", info="0.0 = deterministic\n>1.0 = more random")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-P", info="Nucleus sampling probability.")
                top_k = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K", info="Limit to K most likely tokens (0=disable).")
            gr.Markdown("---")  # separator
            gr.Markdown("ℹ️ **Note:** Uses Phi-4 instruction format: \n`<|user|>\nPROMPT<|end|>\n<|assistant|>`")
            gr.Markdown(f"Running on **{EXECUTION_PROVIDER.upper()}**.")

    # --- Event listeners ---
    # History plus sampling settings, passed to the bot response generator.
    bot_response_inputs = [chatbot, max_length, temperature, top_p, top_k]

    def _current_model_status():
        """Return the global model status string for the status Textbox."""
        return model_status

    def _wire_chat_chain(trigger, api_name):
        """Attach the send chain to *trigger* (a Gradio event-listener method).

        1. add_user_message appends the message to the history and clears the input.
        2. generate_bot_response streams the reply into the Chatbot.
        3. The status Textbox is refreshed. Generation updates model_status
           (e.g. to an error message), but otherwise only Clear pushes it to the UI.
        """
        trigger(
            fn=add_user_message,
            inputs=[prompt_input, chatbot],
            outputs=[prompt_input, chatbot],  # clear the textbox, update history
            queue=False,  # appending the message is fast
        ).then(
            fn=generate_bot_response,
            inputs=bot_response_inputs,
            outputs=[chatbot],  # stream directly into the chatbot
            api_name=api_name,
        ).then(
            fn=_current_model_status,
            inputs=None,
            outputs=[model_status_text],
            queue=False,
            api_name=False,  # internal UI refresh, not an API endpoint
        )

    # Enter in the textbox and the Send button run the same chain. Only the
    # textbox event is exposed as the "chat" API endpoint.
    _wire_chat_chain(prompt_input.submit, "chat")
    _wire_chat_chain(submit_button.click, False)

    # Clear button: reset the input, the chat and the status text.
    clear_button.click(
        fn=clear_chat,
        inputs=None,
        outputs=[prompt_input, chatbot, model_status_text],
        queue=False,  # clearing is fast
    )
# Launch the Gradio app
demo.queue(max_size=20)  # enable queuing with a bounded backlog
# Start the server only when run as a script (e.g. `python app.py`), so
# importing this module does not block on demo.launch().
if __name__ == "__main__":
    logging.info("Launching Gradio App...")
    demo.launch(show_error=True, max_threads=40)