HighCWu committed
Commit 2c2f3fa
1 Parent(s): dc0cb56
Files changed (2)
  1. app.py +407 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,407 @@
+ import gradio as gr
+ import random
+ import re
+ import threading
+ import time
+
+ import spaces
+ import torch
+ import numpy as np
+
+ # Assuming the transformers library is installed
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+ # --- Global Settings ---
+ # These variables are placed in the global scope and will be loaded once when the Gradio app starts
+ system_prompt = []
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
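+ # Each entry maps a display name to [Hugging Face repo id, model identifier];
+ # the identifier is only used for feature detection (e.g. the 'R1' reasoning check below).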
+ MODEL_PATHS = {
+     "Embformer-MiniMind-Base (0.1B)": ["HighCWu/Embformer-MiniMind-Base-0.1B", "Embformer-MiniMind-Base-0.1B"],
+     "Embformer-MiniMind-Seqlen512 (0.1B)": ["HighCWu/Embformer-MiniMind-Seqlen512-0.1B", "Embformer-MiniMind-Seqlen512-0.1B"],
+     "Embformer-MiniMind (0.1B)": ["HighCWu/Embformer-MiniMind-0.1B", "Embformer-MiniMind-0.1B"],
+     "Embformer-MiniMind-RLHF (0.1B)": ["HighCWu/Embformer-MiniMind-RLHF-0.1B", "Embformer-MiniMind-RLHF-0.1B"],
+     "Embformer-MiniMind-R1 (0.1B)": ["HighCWu/Embformer-MiniMind-R1-0.1B", "Embformer-MiniMind-R1-0.1B"],
+ }
+
+ # --- Helper Functions (Mostly unchanged) ---
+
+ def process_assistant_content(content, model_source, selected_model_name):
+     """
+     Processes the model output, converting <think> tags to HTML details elements,
+     and handling content after </think>, filtering out <answer> tags.
+     """
+     is_r1_model = False
+     if model_source == "API":
+         if 'R1' in selected_model_name:
+             is_r1_model = True
+     else:
+         model_identifier = MODEL_PATHS.get(selected_model_name, ["", ""])[1]
+         if 'R1' in model_identifier:
+             is_r1_model = True
+
+     if not is_r1_model:
+         return content
+
+     # Fully closed <think>...</think> block
+     if '<think>' in content and '</think>' in content:
+         # Using re.split is more robust than finding indices
+         parts = re.split(r'(</think>)', content, 1)
+         think_part = parts[0] + parts[1]  # All content from <think> to </think>
+         after_think_part = parts[2] if len(parts) > 2 else ""
+
+         # 1. Process the think part
+         processed_think = re.sub(
+             r'(<think>)(.*?)(</think>)',
+             r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">Reasoning (Click to expand)</summary>\2</details>',
+             think_part,
+             flags=re.DOTALL
+         )
+
+         # 2. Process the part after </think>, filtering <answer> tags
+         # Using re.sub to replace <answer> and </answer> with an empty string
+         processed_after_think = re.sub(r'</?answer>', '', after_think_part)
+
+         # 3. Concatenate the results
+         return processed_think + processed_after_think
+
+     # Only an opening <think>, indicating reasoning is in progress
+     if '<think>' in content and '</think>' not in content:
+         return re.sub(
+             r'<think>(.*?)$',
+             r'<details open style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">Reasoning...</summary>\1</details>',
+             content,
+             flags=re.DOTALL
+         )
+
+     # This case should be rare in streaming output, but kept for completeness
+     if '<think>' not in content and '</think>' in content:
+         # Also need to process content after </think>
+         parts = re.split(r'(</think>)', content, 1)
+         think_part = parts[0] + parts[1]
+         after_think_part = parts[2] if len(parts) > 2 else ""
+
+         processed_think = re.sub(
+             r'(.*?)</think>',
+             r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">Reasoning (Click to expand)</summary>\1</details>',
+             think_part,
+             flags=re.DOTALL
+         )
+         processed_after_think = re.sub(r'</?answer>', '', after_think_part)
+
+         return processed_think + processed_after_think
+
+     # If there are no <think> tags, return the content directly
+     return content
+
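+ # Example (illustration only, not used by the app): for an R1 model,
+ #   process_assistant_content('<think>reasoning</think><answer>42</answer>', 'Local Model', 'Embformer-MiniMind-R1 (0.1B)')
+ # returns the reasoning wrapped in a collapsed <details> block followed by '42'
+ # with the <answer> tags stripped; non-R1 models get their content back unchanged.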
+
+ def setup_seed(seed):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if device != "cpu":
+         torch.cuda.manual_seed(seed)
+         torch.cuda.manual_seed_all(seed)
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
+
+ # --- Gradio App Logic ---
+
+ # Gradio uses global variables or functions to load models, similar to st.cache_resource
+ # We cache models and tokenizers in a dictionary to avoid reloading
+ loaded_models = {}
+
+ def load_model_tokenizer_gradio(model_name):
+     """
+     Gradio version of the model loading function with caching.
+     """
+     if model_name in loaded_models:
+         # print(f"Using cached model: {model_name}")
+         return loaded_models[model_name]
+
+     # print(f"Loading model: {model_name}...")
+     model_path = MODEL_PATHS[model_name][0]
+     model = AutoModelForCausalLM.from_pretrained(
+         model_path,
+         trust_remote_code=True,
+         cache_dir=".cache",
+     ).to(device).eval()
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_path,
+         trust_remote_code=True,
+         cache_dir=".cache",
+     )
+     loaded_models[model_name] = (model, tokenizer)
+     print("Model loaded.")
+     return model, tokenizer
+
+ @spaces.GPU
+ def chat_fn(
+     user_message,
+     history,
+     model_source,
+     # Local model settings
+     selected_model,
+     # API settings
+     api_url,
+     api_model_id,
+     api_model_name,
+     api_key,
+     # Generation parameters
+     history_chat_num,
+     max_new_tokens,
+     temperature
+ ):
+     """
+     Gradio's core chat processing function.
+     It receives the current values of all UI components as input.
+     """
+     history = history or []
+
+     # Build context for the model from the passed history, which is a list of
+     # {"role": ..., "content": ...} dicts because the Chatbot uses type="messages"
+     chat_messages_for_model = []
+     # Limit the number of history messages sent to the model
+     if history_chat_num > 0 and len(history) > history_chat_num:
+         relevant_history = history[-history_chat_num:]
+     else:
+         relevant_history = history
+
+     for msg in relevant_history:
+         chat_messages_for_model.append({"role": msg["role"], "content": msg["content"]})
+
+     # Add the current user message to the model's context
+     chat_messages_for_model.append({"role": "user", "content": user_message})
+
+     final_chat_messages = system_prompt + chat_messages_for_model
+
+     # Now, update the history for UI display: append the new user message and an
+     # empty assistant placeholder that the streaming loop below fills in
+     history.append({"role": "user", "content": user_message})
+     history.append({"role": "assistant", "content": ""})
+
+     # --- Model Invocation ---
+     if model_source == "API":
+         try:
+             from openai import OpenAI
+             client = OpenAI(api_key=api_key, base_url=api_url)
+
+             response = client.chat.completions.create(
+                 model=api_model_id,
+                 messages=final_chat_messages,
+                 stream=True,
+                 temperature=temperature
+             )
+
+             answer = ""
+             for chunk in response:
+                 content = chunk.choices[0].delta.content or ""
+                 answer += content
+                 processed_answer = process_assistant_content(answer, model_source, api_model_name)
+                 history[-1]["content"] = processed_answer
+                 yield history, history
+
+         except Exception as e:
+             history[-1]["content"] = f"API call error: {str(e)}"
+             yield history, history
+
+     else:  # Local Model
+         try:
+             model, tokenizer = load_model_tokenizer_gradio(selected_model)
+
+             random_seed = random.randint(0, 2**32 - 1)
+             setup_seed(random_seed)
+
+             new_prompt = tokenizer.apply_chat_template(
+                 final_chat_messages,
+                 tokenize=False,
+                 add_generation_prompt=True
+             )
+
+             inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
+             streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+             generation_kwargs = {
+                 "input_ids": inputs.input_ids,
+                 "attention_mask": inputs.attention_mask,
+                 "max_new_tokens": max_new_tokens,
+                 "num_return_sequences": 1,
+                 "do_sample": True,
+                 "pad_token_id": tokenizer.pad_token_id,
+                 "eos_token_id": tokenizer.eos_token_id,
+                 "temperature": temperature,
+                 "top_p": 0.85,
+                 "streamer": streamer,
+             }
+
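+             # Run generation in a background thread; TextIteratorStreamer then yields
+             # partial text back to this generator so the chat UI updates as tokens arrive.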
+             thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
+             thread.start()
+
+             answer = ""
+             for new_text in streamer:
+                 answer += new_text
+                 processed_answer = process_assistant_content(answer, model_source, selected_model)
+                 history[-1]["content"] = processed_answer
+                 yield history, history
+         except Exception as e:
+             history[-1]["content"] = f"Local model call error: {str(e)}"
+             yield history, history
+
+ # --- Gradio UI Layout ---
+ css = """
+ .gradio-container { font-family: 'sans-serif'; }
+ footer { display: none !important; }
+ """
+ image_url = "https://chunte-hfba.static.hf.space/images/modern%20Huggies/Huggy%20Sunny%20hello.png"
+
+ # Define example data
+ prompt_datas = [
+     '请介绍一下自己。',
+     '你更擅长哪一个学科?',
+     '鲁迅的《狂人日记》是如何批判封建礼教的?',
+     '我咳嗽已经持续了两周,需要去医院检查吗?',
+     '详细的介绍光速的物理概念。',
+     '推荐一些杭州的特色美食吧。',
+     '请为我讲解“大语言模型”这个概念。',
+     '如何理解ChatGPT?',
+     'Introduce the history of the United States, please.'
+ ]
+
+ with gr.Blocks(theme='soft', css=css) as demo:
+     # History state, this is the Gradio equivalent of st.session_state
+     chat_history = gr.State([])
+     chat_input_cache = gr.State("")
+
+     # Top Title and Badge
+     title_html = """
+     <div style="text-align: center;">
+         <h1>Embformer: An Embedding-Weight-Only Transformer Architecture</h1>
+         <div style="display: flex; justify-content: center; align-items: center; gap: 8px; margin-top: 10px;">
+             <a href="https://doi.org/10.5281/zenodo.15736957">
+                 <img src="https://img.shields.io/badge/DOI-10.5281%2Fzenodo.15736957-blue.svg" alt="DOI">
+             </a>
+             <a href="https://github.com/HighCWu/embformer">
+                 <img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" alt="code">
+             </a>
+             <a href="https://huggingface.co/collections/HighCWu/embformer-minimind-685be74dc761610439241bd5">
+                 <img src="https://img.shields.io/badge/Model-🤗-yellow" alt="model">
+             </a>
+         </div>
+     </div>
+     """
+     gr.HTML(title_html)
+     gr.Markdown("""
+     This is the official demo of [Embformer: An Embedding-Weight-Only Transformer Architecture](https://doi.org/10.5281/zenodo.15736957).
+
+     **Note**: The models in this demo were trained on the MiniMind dataset, which is mostly Chinese, so please chat in Chinese where possible.
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=1, min_width=200):
+             gr.Markdown("### Model Settings")
+
+             # Model source switcher
+             model_source_radio = gr.Radio(["Local Model", "API"], value="Local Model", label="Select Model Source", visible=False)
+
+             # Local model settings
+             with gr.Group(visible=True) as local_model_group:
+                 selected_model_dd = gr.Dropdown(
+                     list(MODEL_PATHS.keys()),
+                     value="Embformer-MiniMind (0.1B)",
+                     label="Select Local Model"
+                 )
+
+             # API settings
+             with gr.Group(visible=False) as api_model_group:
+                 api_url_tb = gr.Textbox("http://127.0.0.1:8000/v1", label="API URL")
+                 api_model_id_tb = gr.Textbox("embformer-minimind", label="Model ID")
+                 api_model_name_tb = gr.Textbox("Embformer-MiniMind (0.1B)", label="Model Name (for feature detection)")
+                 api_key_tb = gr.Textbox("none", label="API Key", type="password")
+
+             # Common generation parameters
+             history_chat_num_slider = gr.Slider(0, 6, value=0, step=2, label="History Turns")
+             max_new_tokens_slider = gr.Slider(256, 8192, value=1024, step=1, label="Max New Tokens")
+             temperature_slider = gr.Slider(0.6, 1.2, value=0.85, step=0.01, label="Temperature")
+
+             # Clear history button
+             clear_btn = gr.Button("🗑️ Clear History")
+
+         with gr.Column(scale=4):
+             gr.Markdown("### Chat")
+
+             chatbot = gr.Chatbot(
+                 [],
+                 elem_id="chatbot",
+                 avatar_images=(None, image_url),
+                 type="messages",
+                 height=350
+             )
+             chat_input = gr.Textbox(
+                 show_label=False,
+                 placeholder="Send a message to MiniMind... (Enter to send)",
+                 container=False,
+                 scale=7,
+                 elem_id="chat-textbox",
+             )
+             examples = gr.Examples(
+                 examples=prompt_datas,
+                 inputs=chat_input,  # After clicking, the example content will fill chat_input
+                 label="Click an example to ask (will automatically clear chat and continue)"
+             )
+
+     # --- Event Listeners and Bindings ---
+
+     # Show/hide corresponding setting groups when switching model source
+     def toggle_model_source_ui(source):
+         return {
+             local_model_group: gr.update(visible=source == "Local Model"),
+             api_model_group: gr.update(visible=source == "API")
+         }
+     model_source_radio.change(
+         fn=toggle_model_source_ui,
+         inputs=model_source_radio,
+         outputs=[local_model_group, api_model_group]
+     )
+
+     # Define the list of input components for the submit event
+     submit_inputs = [
+         chat_input_cache, chat_history, model_source_radio, selected_model_dd,
+         api_url_tb, api_model_id_tb, api_model_name_tb, api_key_tb,
+         history_chat_num_slider, max_new_tokens_slider, temperature_slider
+     ]
+
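+     # chat_fn reads the user text from chat_input_cache (first entry above) rather than
+     # from the textbox itself, so the visible textbox can be cleared as soon as a message is sent.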
+     # When chat_input is submitted (user presses enter or an example is clicked), run chat_fn
+     submit_event = chat_input.submit(
+         fn=lambda text: ("", text),  # Clear the textbox and cache the submitted text
+         inputs=chat_input,
+         outputs=[chat_input, chat_input_cache],
+     ).then(
+         fn=chat_fn,
+         inputs=submit_inputs,
+         outputs=[chatbot, chat_history],
+     )
+
+     # Event chain for clicking an example
+     examples.load_input_event.then(
+         fn=lambda text: ("", text, [], []),  # Cache the example text and clear the chat history
+         inputs=chat_input,
+         outputs=[chat_input, chat_input_cache, chatbot, chat_history],
+     ).then(
+         fn=chat_fn,  # Then answer the example with the current settings
+         inputs=submit_inputs,
+         outputs=[chatbot, chat_history],
+     )
+
+     # Clear history button logic
+     def clear_history():
+         return [], []
+     clear_btn.click(fn=clear_history, outputs=[chatbot, chat_history])
+     chatbot.clear(fn=clear_history, outputs=[chatbot, chat_history])
+
+
+ if __name__ == "__main__":
+     # Pre-load the default model on startup
+     print("Pre-loading default model...")
+     load_model_tokenizer_gradio("Embformer-MiniMind (0.1B)")
+
+     # Launch the Gradio app
+     demo.queue().launch(share=False)
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ transformers @ git+https://github.com/huggingface/transformers.git@cb0f604
+ gradio<=5.23.0
+ spaces<=0.37.1