import gradio as gr
import random
import re
import threading
import time
import spaces
import torch
import numpy as np
# Assuming the transformers library is installed
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
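# Note: `spaces` is the Hugging Face Spaces helper package; it provides the @spaces.GPU
# decorator used below to request GPU time on ZeroGPU Spaces.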
# --- Global Settings ---
# These variables are placed in the global scope and will be loaded once when the Gradio app starts
system_prompt = []
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATHS = {
"Embformer-MiniMind-Base (0.1B)": ["HighCWu/Embformer-MiniMind-Base-0.1B", "Embformer-MiniMind-Base-0.1B"],
"Embformer-MiniMind-Seqlen512 (0.1B)": ["HighCWu/Embformer-MiniMind-Seqlen512-0.1B", "Embformer-MiniMind-Seqlen512-0.1B"],
"Embformer-MiniMind (0.1B)": ["HighCWu/Embformer-MiniMind-0.1B", "Embformer-MiniMind-0.1B"],
"Embformer-MiniMind-RLHF (0.1B)": ["HighCWu/Embformer-MiniMind-RLHF-0.1B", "Embformer-MiniMind-RLHF-0.1B"],
"Embformer-MiniMind-R1 (0.1B)": ["HighCWu/Embformer-MiniMind-R1-0.1B", "Embformer-MiniMind-R1-0.1B"],
}
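# Each entry maps a display name to [Hugging Face repo id, short model identifier];
# the identifier is used for feature detection (e.g. spotting "R1" reasoning models).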
# --- Helper Functions (mostly unchanged from the Streamlit version) ---
def process_assistant_content(content, model_source, selected_model_name):
"""
    Processes the model output, converting <think> tags into collapsible HTML <details>
    elements, and handling the content after </think>, filtering out <answer> tags.
"""
is_r1_model = False
if model_source == "API":
if 'R1' in selected_model_name:
is_r1_model = True
else:
model_identifier = MODEL_PATHS.get(selected_model_name, ["", ""])[1]
if 'R1' in model_identifier:
is_r1_model = True
if not is_r1_model:
return content
    # Fully closed <think>...</think> block
    if '<think>' in content and '</think>' in content:
        # Using re.split is more robust than finding indices
        parts = re.split(r'(</think>)', content, maxsplit=1)
        think_part = parts[0] + parts[1]  # Everything up to and including </think>
        after_think_part = parts[2] if len(parts) > 2 else ""
        # 1. Wrap the think part in a collapsible <details> element
        processed_think = re.sub(
            r'(<think>)(.*?)(</think>)',
            r'<details><summary>Reasoning (Click to expand)</summary>\2</details>',
            think_part,
            flags=re.DOTALL
        )
        # 2. Process the part after </think>, filtering <answer> tags
        # Using re.sub to replace <answer> and </answer> with an empty string
        processed_after_think = re.sub(r'</?answer>', '', after_think_part)
        # 3. Concatenate the results
        return processed_think + processed_after_think
    # Only an opening <think>, indicating reasoning is still in progress
    if '<think>' in content and '</think>' not in content:
        return re.sub(
            r'<think>(.*?)$',
            # Keep the details element open while the reasoning is still streaming
            r'<details open><summary>Reasoning...</summary>\1</details>',
            content,
            flags=re.DOTALL
        )
    # Only a closing </think>; this should be rare in streaming output, but kept for completeness
    if '<think>' not in content and '</think>' in content:
        # Also need to process content after </think>
        parts = re.split(r'(</think>)', content, maxsplit=1)
        think_part = parts[0] + parts[1]
        after_think_part = parts[2] if len(parts) > 2 else ""
        processed_think = re.sub(
            r'(.*?)</think>',
            r'<details><summary>Reasoning (Click to expand)</summary>\1</details>',
            think_part,
            flags=re.DOTALL
        )
        processed_after_think = re.sub(r'</?answer>', '', after_think_part)
        return processed_think + processed_after_think
    # If there are no <think> tags, return the content directly
return content
def setup_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if device != "cpu":
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
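        # Deterministic cuDNN (with benchmarking disabled) trades a little speed for
        # reproducible sampling given the same seed.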
# --- Gradio App Logic ---
# Gradio has no direct equivalent of Streamlit's st.cache_resource, so we cache
# loaded models and tokenizers in a module-level dictionary to avoid reloading them.
loaded_models = {}
def load_model_tokenizer_gradio(model_name):
"""
Gradio version of the model loading function with caching.
"""
if model_name in loaded_models:
# print(f"Using cached model: {model_name}")
return loaded_models[model_name]
# print(f"Loading model: {model_name}...")
model_path = MODEL_PATHS[model_name][0]
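    # trust_remote_code=True is needed because Embformer ships its custom modeling code
    # with the checkpoint rather than being part of the transformers library.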
model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
cache_dir=".cache",
).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
cache_dir=".cache",
)
loaded_models[model_name] = (model, tokenizer)
print("Model loaded.")
return model, tokenizer
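# @spaces.GPU asks Hugging Face ZeroGPU Spaces to attach a GPU for the duration of this call;
# outside of Spaces (regular GPU machine or CPU) the decorator is effectively a no-op.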
@spaces.GPU
def chat_fn(
user_message,
history,
model_source,
# Local model settings
selected_model,
# API settings
api_url,
api_model_id,
api_model_name,
api_key,
# Generation parameters
history_chat_num,
max_new_tokens,
temperature
):
"""
Gradio's core chat processing function.
It receives the current values of all UI components as input.
"""
history = history or []
    # Build the context for the model from the history (messages format: dicts with "role"/"content")
    chat_messages_for_model = []
    # Limit the amount of history; the slider steps by 2, so its value counts individual
    # messages (user/assistant pairs)
    if history_chat_num > 0 and len(history) > history_chat_num:
        relevant_history = history[-history_chat_num:]
    else:
        relevant_history = history
    for msg in relevant_history:
        chat_messages_for_model.append({"role": msg["role"], "content": msg["content"]})
    # Add the current user message to the model's context
    chat_messages_for_model.append({"role": "user", "content": user_message})
    final_chat_messages = system_prompt + chat_messages_for_model
    # Now, update the history for UI display: the new user message plus an empty assistant
    # placeholder that is filled in as the answer streams
    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": ""})
# --- Model Invocation ---
if model_source == "API":
try:
from openai import OpenAI
client = OpenAI(api_key=api_key, base_url=api_url)
response = client.chat.completions.create(
model=api_model_id,
messages=final_chat_messages,
stream=True,
temperature=temperature
)
answer = ""
for chunk in response:
content = chunk.choices[0].delta.content or ""
answer += content
processed_answer = process_assistant_content(answer, model_source, api_model_name)
history[-1]["content"] = processed_answer
yield history, history
except Exception as e:
history[-1]["content"] = f"API call error: {str(e)}"
yield history, history
else: # Local Model
try:
model, tokenizer = load_model_tokenizer_gradio(selected_model)
random_seed = random.randint(0, 2**32 - 1)
setup_seed(random_seed)
new_prompt = tokenizer.apply_chat_template(
final_chat_messages,
tokenize=False,
add_generation_prompt=True
)
inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = {
"input_ids": inputs.input_ids,
"attention_mask": inputs.attention_mask,
"max_new_tokens": max_new_tokens,
"num_return_sequences": 1,
"do_sample": True,
"pad_token_id": tokenizer.pad_token_id,
"eos_token_id": tokenizer.eos_token_id,
"temperature": temperature,
"top_p": 0.85,
"streamer": streamer,
}
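            # Run generate() in a background thread; TextIteratorStreamer then yields decoded
            # text chunks in this thread as soon as they are produced, enabling streaming output.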
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
answer = ""
for new_text in streamer:
answer += new_text
processed_answer = process_assistant_content(answer, model_source, selected_model)
history[-1]["content"] = processed_answer
yield history, history
except Exception as e:
history[-1]["content"] = f"Local model call error: {str(e)}"
yield history, history
# --- Gradio UI Layout ---
css = """
.gradio-container { font-family: sans-serif; }
footer { display: none !important; }
"""
image_url = "https://chunte-hfba.static.hf.space/images/modern%20Huggies/Huggy%20Sunny%20hello.png"
# Define example data
prompt_datas = [
'请介绍一下自己。',
'你更擅长哪一个学科?',
'鲁迅的《狂人日记》是如何批判封建礼教的?',
'我咳嗽已经持续了两周,需要去医院检查吗?',
'详细的介绍光速的物理概念。',
'推荐一些杭州的特色美食吧。',
'请为我讲解“大语言模型”这个概念。',
'如何理解ChatGPT?',
'Introduce the history of the United States, please.'
]
with gr.Blocks(theme='soft', css=css) as demo:
    # Chat history state: Gradio's per-session equivalent of st.session_state
chat_history = gr.State([])
chat_input_cache = gr.State("")
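    # chat_input_cache holds the submitted text so the visible textbox can be cleared
    # immediately while chat_fn still receives the message (see the submit chain below).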
# Top Title and Badge
title_html = """
    <h1 style="text-align: center;">Embformer: An Embedding-Weight-Only Transformer Architecture</h1>
"""
gr.HTML(title_html)
gr.Markdown("""
This is the official demo of [Embformer: An Embedding-Weight-Only Transformer Architecture](https://doi.org/10.5281/zenodo.15736957).
    **Note**: The model in this demo was trained on data derived from the MiniMind dataset, which is largely Chinese, so please use Chinese in the conversation where possible.
""")
with gr.Row():
with gr.Column(scale=1, min_width=200):
gr.Markdown("### Model Settings")
# Model source switcher
model_source_radio = gr.Radio(["Local Model", "API"], value="Local Model", label="Select Model Source", visible=False)
# Local model settings
with gr.Group(visible=True) as local_model_group:
selected_model_dd = gr.Dropdown(
list(MODEL_PATHS.keys()),
value="Embformer-MiniMind (0.1B)",
label="Select Local Model"
)
# API settings
with gr.Group(visible=False) as api_model_group:
api_url_tb = gr.Textbox("http://127.0.0.1:8000/v1", label="API URL")
api_model_id_tb = gr.Textbox("embformer-minimind", label="Model ID")
api_model_name_tb = gr.Textbox("Embformer-MiniMind (0.1B)", label="Model Name (for feature detection)")
api_key_tb = gr.Textbox("none", label="API Key", type="password")
# Common generation parameters
history_chat_num_slider = gr.Slider(0, 6, value=0, step=2, label="History Turns")
max_new_tokens_slider = gr.Slider(256, 8192, value=1024, step=1, label="Max New Tokens")
temperature_slider = gr.Slider(0.6, 1.2, value=0.85, step=0.01, label="Temperature")
# Clear history button
clear_btn = gr.Button("🗑️ Clear History")
with gr.Column(scale=4):
gr.Markdown("### Chat")
chatbot = gr.Chatbot(
[],
elem_id="chatbot",
avatar_images=(None, image_url),
type="messages",
height=350
)
chat_input = gr.Textbox(
show_label=False,
placeholder="Send a message to MiniMind... (Enter to send)",
container=False,
scale=7,
elem_id="chat-textbox",
)
examples = gr.Examples(
examples=prompt_datas,
inputs=chat_input, # After clicking, the example content will fill chat_input
label="Click an example to ask (will automatically clear chat and continue)"
)
# --- Event Listeners and Bindings ---
# Show/hide corresponding setting groups when switching model source
def toggle_model_source_ui(source):
return {
local_model_group: gr.update(visible=source == "Local Model"),
api_model_group: gr.update(visible=source == "API")
}
model_source_radio.change(
fn=toggle_model_source_ui,
inputs=model_source_radio,
outputs=[local_model_group, api_model_group]
)
# Define the list of input components for the submit event
submit_inputs = [
chat_input_cache, chat_history, model_source_radio, selected_model_dd,
api_url_tb, api_model_id_tb, api_model_name_tb, api_key_tb,
history_chat_num_slider, max_new_tokens_slider, temperature_slider
]
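    # The order of submit_inputs must match the parameter order of chat_fn
    # (user_message, history, model_source, selected_model, API settings, generation settings).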
    # When the user presses Enter in chat_input, stash the text, clear the textbox, then run chat_fn
submit_event = chat_input.submit(
fn=lambda text: ("", text),
inputs=chat_input,
outputs=[chat_input, chat_input_cache],
).then(
fn=chat_fn,
inputs=submit_inputs,
outputs=[chatbot, chat_history],
)
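    # .then() chains the two steps: the lambda clears the textbox instantly, and chat_fn
    # then streams its generator output into the chatbot and the chat_history state.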
# Event chain for clicking an example
examples.load_input_event.then(
        fn=lambda text: ("", text, [], []),  # Stash the example text, clear the textbox, and reset the chat
        inputs=chat_input,
        outputs=[chat_input, chat_input_cache, chatbot, chat_history],
).then(
        fn=chat_fn,  # Run the same chat function on the example text
inputs=submit_inputs, # Pass example text and other settings
outputs=[chatbot, chat_history],
)
# Clear history button logic
def clear_history():
return [], []
clear_btn.click(fn=clear_history, outputs=[chatbot, chat_history])
chatbot.clear(fn=clear_history, outputs=[chatbot, chat_history])
if __name__ == "__main__":
# Pre-load the default model on startup
print("Pre-loading default model...")
load_model_tokenizer_gradio("Embformer-MiniMind (0.1B)")
# Launch the Gradio app
demo.queue().launch(share=False)