import gradio as gr
import random
import re
import threading
import time
import spaces
import torch
import numpy as np

# Assuming the transformers library is installed
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# --- Global Settings ---
# These variables live in the global scope and are loaded once when the Gradio app starts
system_prompt = []
device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_PATHS = {
    "Embformer-MiniMind-Base (0.1B)": ["HighCWu/Embformer-MiniMind-Base-0.1B", "Embformer-MiniMind-Base-0.1B"],
    "Embformer-MiniMind-Seqlen512 (0.1B)": ["HighCWu/Embformer-MiniMind-Seqlen512-0.1B", "Embformer-MiniMind-Seqlen512-0.1B"],
    "Embformer-MiniMind (0.1B)": ["HighCWu/Embformer-MiniMind-0.1B", "Embformer-MiniMind-0.1B"],
    "Embformer-MiniMind-RLHF (0.1B)": ["HighCWu/Embformer-MiniMind-RLHF-0.1B", "Embformer-MiniMind-RLHF-0.1B"],
    "Embformer-MiniMind-R1 (0.1B)": ["HighCWu/Embformer-MiniMind-R1-0.1B", "Embformer-MiniMind-R1-0.1B"],
}
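# Each entry maps a display name to [Hugging Face repo id, short identifier].
# Illustrative use of the mapping (the repo id feeds from_pretrained, the short
# identifier is only inspected for the "R1" feature flag further below):
#   repo_id, short_name = MODEL_PATHS["Embformer-MiniMind-R1 (0.1B)"]
#   # repo_id    -> "HighCWu/Embformer-MiniMind-R1-0.1B"
#   # short_name -> "Embformer-MiniMind-R1-0.1B"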
# --- Helper Functions (mostly unchanged) ---
def process_assistant_content(content, model_source, selected_model_name):
    """
    Processes the model output, converting <think> tags to HTML <details> elements,
    handling content after </think>, and filtering out <answer> tags.
    """
    is_r1_model = False
    if model_source == "API":
        if 'R1' in selected_model_name:
            is_r1_model = True
    else:
        model_identifier = MODEL_PATHS.get(selected_model_name, ["", ""])[1]
        if 'R1' in model_identifier:
            is_r1_model = True

    if not is_r1_model:
        return content

    # Fully closed <think>...</think> block
    if '<think>' in content and '</think>' in content:
        # Using re.split is more robust than finding indices
        parts = re.split(r'(</think>)', content, maxsplit=1)
        think_part = parts[0] + parts[1]  # All content from <think> to </think>
        after_think_part = parts[2] if len(parts) > 2 else ""

        # 1. Process the think part
        processed_think = re.sub(
            r'(<think>)(.*?)(</think>)',
            r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">Reasoning (Click to expand)</summary>\2</details>',
            think_part,
            flags=re.DOTALL
        )
        # 2. Process the part after </think>, filtering <answer> tags
        # Using re.sub to replace <answer> and </answer> with an empty string
        processed_after_think = re.sub(r'</?answer>', '', after_think_part)
        # 3. Concatenate the results
        return processed_think + processed_after_think

    # Only an opening <think>, indicating reasoning is still in progress
    if '<think>' in content and '</think>' not in content:
        return re.sub(
            r'<think>(.*?)$',
            r'<details open style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">Reasoning...</summary>\1</details>',
            content,
            flags=re.DOTALL
        )

    # This case should be rare in streaming output, but is kept for completeness
    if '<think>' not in content and '</think>' in content:
        # The content after </think> also needs processing
        parts = re.split(r'(</think>)', content, maxsplit=1)
        think_part = parts[0] + parts[1]
        after_think_part = parts[2] if len(parts) > 2 else ""

        processed_think = re.sub(
            r'(.*?)</think>',
            r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">Reasoning (Click to expand)</summary>\1</details>',
            think_part,
            flags=re.DOTALL
        )
        processed_after_think = re.sub(r'</?answer>', '', after_think_part)
        return processed_think + processed_after_think

    # If there are no <think> tags, return the content as-is
    return content
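# Illustrative example of the transformation above (R1 models only): an input of
#   "<think>Check both cases.</think><answer>42</answer>"
# becomes a collapsed <details> element labelled "Reasoning (Click to expand)"
# wrapping "Check both cases.", followed by "42" with the <answer> tags stripped.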
def setup_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if device != "cpu":
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
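# Usage sketch: calling setup_seed(n) immediately before model.generate(...)
# makes the sampled output reproducible for a given prompt and seed; note that
# cudnn.deterministic=True can disable some faster non-deterministic kernels.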
# --- Gradio App Logic ---
# Gradio has no direct equivalent of st.cache_resource, so models and
# tokenizers are cached in a module-level dictionary to avoid reloading.
loaded_models = {}


def load_model_tokenizer_gradio(model_name):
    """
    Gradio version of the model loading function, with caching.
    """
    if model_name in loaded_models:
        # print(f"Using cached model: {model_name}")
        return loaded_models[model_name]

    # print(f"Loading model: {model_name}...")
    model_path = MODEL_PATHS[model_name][0]
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        cache_dir=".cache",
    ).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
        cache_dir=".cache",
    )
    loaded_models[model_name] = (model, tokenizer)
    print("Model loaded.")
    return model, tokenizer
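# Usage sketch (names taken from MODEL_PATHS above):
#   model, tokenizer = load_model_tokenizer_gradio("Embformer-MiniMind (0.1B)")
# Repeated calls with the same name return the cached pair instead of reloading.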
# `spaces.GPU` requests a GPU slot when running on Hugging Face ZeroGPU hardware
# and is a no-op elsewhere (assumption: this Space targets ZeroGPU, which is why
# `spaces` is imported above).
@spaces.GPU
def chat_fn(
    user_message,
    history,
    model_source,
    # Local model settings
    selected_model,
    # API settings
    api_url,
    api_model_id,
    api_model_name,
    api_key,
    # Generation parameters
    history_chat_num,
    max_new_tokens,
    temperature
):
    """
    Gradio's core chat processing function.
    It receives the current values of all UI components as input and streams
    (chatbot_messages, chat_history) updates back to the UI.
    """
    history = history or []

    # Build the model context from the (unmodified) history.
    # `history` holds {"role": ..., "content": ...} message dicts, so
    # `history_chat_num` limits the number of individual messages kept
    # (two messages per conversation turn).
    if history_chat_num > 0 and len(history) > history_chat_num:
        relevant_history = history[-history_chat_num:]
    else:
        relevant_history = history

    chat_messages_for_model = [
        {"role": msg["role"], "content": msg["content"]} for msg in relevant_history
    ]

    # Add the current user message to the model's context
    chat_messages_for_model.append({"role": "user", "content": user_message})
    final_chat_messages = system_prompt + chat_messages_for_model

    # Now update the history for UI display: append the new user message and an
    # empty assistant placeholder that the streaming loops below fill in.
    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": ""})
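    # Illustrative shape of the data at this point (system_prompt is empty here):
    #   final_chat_messages = [..trimmed earlier messages.., {"role": "user", "content": user_message}]
    # while `history` additionally ends with the empty assistant placeholder
    # that the streaming loops below overwrite chunk by chunk.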
    # --- Model Invocation ---
    if model_source == "API":
        try:
            from openai import OpenAI
            client = OpenAI(api_key=api_key, base_url=api_url)
            response = client.chat.completions.create(
                model=api_model_id,
                messages=final_chat_messages,
                stream=True,
                temperature=temperature
            )
            answer = ""
            for chunk in response:
                content = chunk.choices[0].delta.content or ""
                answer += content
                processed_answer = process_assistant_content(answer, model_source, api_model_name)
                history[-1]["content"] = processed_answer
                yield history, history
        except Exception as e:
            history[-1]["content"] = f"API call error: {str(e)}"
            yield history, history
    else:  # Local model
        try:
            model, tokenizer = load_model_tokenizer_gradio(selected_model)
            random_seed = random.randint(0, 2**32 - 1)
            setup_seed(random_seed)

            new_prompt = tokenizer.apply_chat_template(
                final_chat_messages,
                tokenize=False,
                add_generation_prompt=True
            )
            inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)

            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
            generation_kwargs = {
                "input_ids": inputs.input_ids,
                "attention_mask": inputs.attention_mask,
                "max_new_tokens": max_new_tokens,
                "num_return_sequences": 1,
                "do_sample": True,
                "pad_token_id": tokenizer.pad_token_id,
                "eos_token_id": tokenizer.eos_token_id,
                "temperature": temperature,
                "top_p": 0.85,
                "streamer": streamer,
            }
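            # model.generate() blocks until generation finishes, so it runs in a
            # background thread while the TextIteratorStreamer yields decoded
            # text fragments to the loop below as soon as they are produced.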
            thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
            thread.start()

            answer = ""
            for new_text in streamer:
                answer += new_text
                processed_answer = process_assistant_content(answer, model_source, selected_model)
                history[-1]["content"] = processed_answer
                yield history, history
        except Exception as e:
            history[-1]["content"] = f"Local model call error: {str(e)}"
            yield history, history
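# Note: chat_fn is a generator; each `yield history, history` sends a partial
# update to both the Chatbot and the chat_history state, which is what makes
# the reply appear to stream in the UI.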
# --- Gradio UI Layout ---
css = """
.gradio-container { font-family: 'sans-serif'; }
footer { display: none !important; }
"""

image_url = "https://chunte-hfba.static.hf.space/images/modern%20Huggies/Huggy%20Sunny%20hello.png"

# Example prompts. Most are intentionally in Chinese because the underlying
# MiniMind dataset is predominantly Chinese (English glosses in the comments).
prompt_datas = [
    '请介绍一下自己。',                       # Please introduce yourself.
    '你更擅长哪一个学科?',                   # Which subject are you best at?
    '鲁迅的《狂人日记》是如何批判封建礼教的?',  # How does Lu Xun's "A Madman's Diary" criticize feudal ethics?
    '我咳嗽已经持续了两周,需要去医院检查吗?',  # I have been coughing for two weeks; should I see a doctor?
    '详细的介绍光速的物理概念。',              # Explain the physical concept of the speed of light in detail.
    '推荐一些杭州的特色美食吧。',              # Recommend some Hangzhou specialty foods.
    '请为我讲解“大语言模型”这个概念。',        # Explain the concept of "large language model" to me.
    '如何理解ChatGPT?',                      # How should one understand ChatGPT?
    'Introduce the history of the United States, please.'
]
with gr.Blocks(theme='soft', css=css) as demo:
    # History state; this is the Gradio equivalent of st.session_state
    chat_history = gr.State([])
    chat_input_cache = gr.State("")

    # Top title and badges
    title_html = """
    <div style="text-align: center;">
        <h1>Embformer: An Embedding-Weight-Only Transformer Architecture</h1>
        <div style="display: flex; justify-content: center; align-items: center; gap: 8px; margin-top: 10px;">
            <a href="https://doi.org/10.5281/zenodo.15736957">
                <img src="https://img.shields.io/badge/DOI-10.5281%2Fzenodo.15736957-blue.svg" alt="DOI">
            </a>
            <a href="https://github.com/HighCWu/embformer">
                <img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" alt="code">
            </a>
            <a href="https://huggingface.co/collections/HighCWu/embformer-minimind-685be74dc761610439241bd5">
                <img src="https://img.shields.io/badge/Model-🤗-yellow" alt="model">
            </a>
        </div>
    </div>
    """
    gr.HTML(title_html)
    gr.Markdown("""
    This is the official demo of [Embformer: An Embedding-Weight-Only Transformer Architecture](https://doi.org/10.5281/zenodo.15736957).

    **Note**: The models in this demo were trained on data derived from the MiniMind dataset, which is predominantly Chinese, so conversations work best in Chinese.
    """)
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            gr.Markdown("### Model Settings")

            # Model source switcher
            model_source_radio = gr.Radio(
                ["Local Model", "API"],
                value="Local Model",
                label="Select Model Source",
                visible=False,
            )

            # Local model settings
            with gr.Group(visible=True) as local_model_group:
                selected_model_dd = gr.Dropdown(
                    list(MODEL_PATHS.keys()),
                    value="Embformer-MiniMind (0.1B)",
                    label="Select Local Model"
                )

            # API settings
            with gr.Group(visible=False) as api_model_group:
                api_url_tb = gr.Textbox("http://127.0.0.1:8000/v1", label="API URL")
                api_model_id_tb = gr.Textbox("embformer-minimind", label="Model ID")
                api_model_name_tb = gr.Textbox("Embformer-MiniMind (0.1B)", label="Model Name (for feature detection)")
                api_key_tb = gr.Textbox("none", label="API Key", type="password")

            # Common generation parameters
            history_chat_num_slider = gr.Slider(0, 6, value=0, step=2, label="History Turns")
            max_new_tokens_slider = gr.Slider(256, 8192, value=1024, step=1, label="Max New Tokens")
            temperature_slider = gr.Slider(0.6, 1.2, value=0.85, step=0.01, label="Temperature")

            # Clear history button
            clear_btn = gr.Button("🗑️ Clear History")

        with gr.Column(scale=4):
            gr.Markdown("### Chat")
            chatbot = gr.Chatbot(
                [],
                elem_id="chatbot",
                avatar_images=(None, image_url),
                type="messages",
                height=350
            )
            chat_input = gr.Textbox(
                show_label=False,
                placeholder="Send a message to MiniMind... (Enter to send)",
                container=False,
                scale=7,
                elem_id="chat-textbox",
            )
            examples = gr.Examples(
                examples=prompt_datas,
                inputs=chat_input,  # After clicking, the example content fills chat_input
                label="Click an example to ask (will automatically clear the chat and continue)"
            )
    # --- Event Listeners and Bindings ---
    # Show/hide the corresponding settings group when switching model source
    def toggle_model_source_ui(source):
        return {
            local_model_group: gr.update(visible=source == "Local Model"),
            api_model_group: gr.update(visible=source == "API"),
        }
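    # Returning a dict keyed by components (rather than a tuple) lets the
    # callback update just the `visible` property of the two groups listed in
    # `outputs`, leaving their other properties untouched.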
    model_source_radio.change(
        fn=toggle_model_source_ui,
        inputs=model_source_radio,
        outputs=[local_model_group, api_model_group]
    )

    # Input components for the submit event
    submit_inputs = [
        chat_input_cache, chat_history, model_source_radio, selected_model_dd,
        api_url_tb, api_model_id_tb, api_model_name_tb, api_key_tb,
        history_chat_num_slider, max_new_tokens_slider, temperature_slider
    ]

    # When chat_input is submitted (user presses Enter), run chat_fn
    submit_event = chat_input.submit(
        fn=lambda text: ("", text),
        inputs=chat_input,
        outputs=[chat_input, chat_input_cache],
    ).then(
        fn=chat_fn,
        inputs=submit_inputs,
        outputs=[chatbot, chat_history],
    )
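    # The two-step chain above first clears the textbox while stashing its text
    # in chat_input_cache, then runs chat_fn on the cached text; this keeps the
    # input box empty during streaming without losing the submitted prompt.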
    # Event chain for clicking an example
    examples.load_input_event.then(
        fn=lambda text: ("", text, [], []),  # Clear the input box and the chat, cache the example text
        inputs=chat_input,
        outputs=[chat_input, chat_input_cache, chatbot, chat_history],
    ).then(
        fn=chat_fn,
        inputs=submit_inputs,  # chat_input_cache carries the example text; the rest are the current settings
        outputs=[chatbot, chat_history],
    )
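    # Clicking an example therefore behaves like: load the example text, reset
    # the conversation, and immediately ask the model that question.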
    # Clear history button logic
    def clear_history():
        return [], []

    clear_btn.click(fn=clear_history, outputs=[chatbot, chat_history])
    chatbot.clear(fn=clear_history, outputs=[chatbot, chat_history])


if __name__ == "__main__":
    # Pre-load the default model on startup
    print("Pre-loading default model...")
    load_model_tokenizer_gradio("Embformer-MiniMind (0.1B)")

    # Launch the Gradio app; queue() lets the generator-based chat_fn stream its output
    demo.queue().launch(share=False)