Spestly committed on
Commit c7c3bd5 · verified · 1 parent: 50da5e2

Update app.py

Files changed (1): app.py +136 -36
app.py CHANGED
@@ -3,6 +3,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import time
 import spaces
+import re
 
 # Model configurations
 MODELS = {
@@ -44,10 +45,9 @@ def generate_response(model_id, conversation, user_message, max_length=512, temp
     )
     messages.append({"role": "system", "content": system_prompt})
 
-    # Add conversation history (OpenAI-style)
+    # Add conversation history
     for msg in conversation:
-        if msg["role"] in ("user", "assistant"):
-            messages.append({"role": msg["role"], "content": msg["content"]})
+        messages.append(msg)
 
     # Add current user message
     messages.append({"role": "user", "content": user_message})
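Annotation: with this hunk, `generate_response` forwards history entries verbatim instead of filtering and rebuilding them, so callers must already store state in chat-template form. A minimal sketch of the `messages` list the simplified loop assembles (the system prompt text and sample turns are invented placeholders, not from the commit):

```python
conversation = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]
messages = [{"role": "system", "content": "You are Athena."}]  # placeholder system prompt
for msg in conversation:
    messages.append(msg)  # history entries forwarded verbatim
messages.append({"role": "user", "content": "What is 2 + 2?"})  # current user message
```

The new `chat_submit` below keeps `conversation_state` in exactly this dict format, which is why the role-filtering loop became unnecessary.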
@@ -76,33 +76,61 @@ def generate_response(model_id, conversation, user_message, max_length=512, temp
         outputs[0][inputs['input_ids'].shape[-1]:],
         skip_special_tokens=True
     ).strip()
+    print(f"Generation time: {generation_time:.2f}s")
     return response, load_time, generation_time
 
-def respond(history, message, model_name, max_length, temperature):
-    """Main function for custom Chatbot interface"""
+def format_response_with_thinking(response):
+    """Format response to handle <think></think> tags"""
+    # Check if response contains thinking tags
+    if '<think>' in response and '</think>' in response:
+        # Split the response into parts
+        pattern = r'(.*?)(<think>(.*?)</think>)(.*)'
+        match = re.search(pattern, response, re.DOTALL)
+
+        if match:
+            before_thinking = match.group(1).strip()
+            thinking_content = match.group(3).strip()
+            after_thinking = match.group(4).strip()
+
+            # Create HTML with collapsible thinking section
+            html = f"{before_thinking}\n"
+            html += f'<div class="thinking-container">'
+            html += f'<button class="thinking-toggle" onclick="this.nextElementSibling.classList.toggle(\'hidden\'); this.textContent = this.textContent === \'Show reasoning\' ? \'Hide reasoning\' : \'Show reasoning\'">Show reasoning</button>'
+            html += f'<div class="thinking-content hidden">{thinking_content}</div>'
+            html += f'</div>\n'
+            html += after_thinking
+
+            return html
+
+    # If no thinking tags, return the original response
+    return response
+
+def chat_submit(message, chat_history, conversation_state, model_name, max_length, temperature):
+    """Process a new message and update the chat history"""
     if not message.strip():
-        history = history + [["user", message], ["assistant", "Please enter a message"]]
-        return history, ""
-    model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])
+        return "", chat_history, conversation_state
+
+    model_id = MODELS.get(model_name, MODELS["Athena-R3X 4B"])
     try:
-        # Format history for Athena
-        formatted_history = []
-        for i in range(0, len(history), 2):
-            if i < len(history):
-                user_msg = history[i][1] if history[i][0] == "user" else ""
-                assistant_msg = history[i+1][1] if i+1 < len(history) and history[i+1][0] == "assistant" else ""
-                if user_msg:
-                    formatted_history.append({"role": "user", "content": user_msg})
-                if assistant_msg:
-                    formatted_history.append({"role": "assistant", "content": assistant_msg})
         response, load_time, generation_time = generate_response(
-            model_id, formatted_history, message, max_length, temperature
+            model_id, conversation_state, message, max_length, temperature
        )
-        history = history + [["user", message], ["assistant", response]]
-        return history, ""
+
+        # Update the conversation state with the raw response
+        conversation_state.append({"role": "user", "content": message})
+        conversation_state.append({"role": "assistant", "content": response})
+
+        # Format the response for display
+        formatted_response = format_response_with_thinking(response)
+
+        # Update the visible chat history
+        chat_history.append((message, formatted_response))
+
+        return "", chat_history, conversation_state
     except Exception as e:
-        history = history + [["user", message], ["assistant", f"Error: {str(e)}"]]
-        return history, ""
+        error_message = f"Error: {str(e)}"
+        chat_history.append((message, error_message))
+        return "", chat_history, conversation_state
 
 css = """
 .message {
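Annotation: the collapsible-reasoning feature above hinges on the `(.*?)(<think>(.*?)</think>)(.*)` pattern. A self-contained sketch of how that pattern splits a response (sample strings invented):

```python
import re

response = "Paris.<think>The user wants the capital of France.</think>See above."
match = re.search(r'(.*?)(<think>(.*?)</think>)(.*)', response, re.DOTALL)
if match:
    print(match.group(1).strip())  # text shown before the block -> "Paris."
    print(match.group(3).strip())  # collapsible reasoning -> "The user wants the capital of France."
    print(match.group(4).strip())  # text shown after the block -> "See above."
```

Because the inner group is lazy and the trailing group greedy, only the first `<think>…</think>` pair gets wrapped; any later pairs pass through to the visible text unchanged.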
@@ -110,22 +138,73 @@ css = """
     margin: 5px;
     border-radius: 10px;
 }
+
+.thinking-container {
+    margin: 10px 0;
+}
+
+.thinking-toggle {
+    background-color: #f1f1f1;
+    border: 1px solid #ddd;
+    border-radius: 4px;
+    padding: 5px 10px;
+    cursor: pointer;
+    font-size: 0.9em;
+    margin-bottom: 5px;
+    color: #555;
+}
+
+.thinking-content {
+    background-color: #f9f9f9;
+    border-left: 3px solid #ccc;
+    padding: 10px;
+    margin-top: 5px;
+    font-size: 0.95em;
+    color: #555;
+    font-family: monospace;
+    white-space: pre-wrap;
+    overflow-x: auto;
+}
+
+.hidden {
+    display: none;
+}
+"""
+
+# Add JavaScript to handle the toggle functionality
+js = """
+function setupThinkingToggles() {
+    document.querySelectorAll('.thinking-toggle').forEach(button => {
+        button.addEventListener('click', function() {
+            const content = this.nextElementSibling;
+            content.classList.toggle('hidden');
+            this.textContent = content.classList.contains('hidden') ? 'Show reasoning' : 'Hide reasoning';
+        });
+    });
+}
+
+// Run after the page loads and when the chat updates
+document.addEventListener('DOMContentLoaded', setupThinkingToggles);
+const observer = new MutationObserver(setupThinkingToggles);
+observer.observe(document.body, { childList: true, subtree: true });
 """
 
 theme = gr.themes.Monochrome()
 
-with gr.Blocks(title="Athena Playground Chat", css=css, theme=theme) as demo:
+with gr.Blocks(title="Athena Playground Chat", css=css, theme=theme, js=js) as demo:
     gr.Markdown("# 🚀 Athena Playground Chat")
     gr.Markdown("*Powered by HuggingFace ZeroGPU*")
 
-    chatbot = gr.Chatbot(height=500, label="Athena")
-    state = gr.State([]) # chat history
-
+    # State to keep track of the conversation for the model
+    conversation_state = gr.State([])
+
+    chatbot = gr.Chatbot(height=500, label="Athena", render_markdown=True)
+
     with gr.Row():
-        user_input = gr.Textbox(label="Your message", scale=8, autofocus=True)
-        send_btn = gr.Button(value="Send", scale=1)
+        user_input = gr.Textbox(label="Your message", scale=8, autofocus=True, placeholder="Type your message here...")
+        send_btn = gr.Button(value="Send", scale=1, variant="primary")
 
-    # --- Configuration controls at the bottom ---
+    # Configuration controls
     gr.Markdown("### ⚙️ Model & Generation Settings")
     with gr.Row():
         model_choice = gr.Dropdown(
@@ -135,7 +214,7 @@ with gr.Blocks(title="Athena Playground Chat", css=css, theme=theme) as demo:
         info="Select which Athena model to use"
     )
     max_length = gr.Slider(
-        32, 2048, value=512,
+        32, 8000, value=512,
         label="📝 Max Tokens",
         info="Maximum number of tokens to generate"
     )
@@ -145,14 +224,35 @@ with gr.Blocks(title="Athena Playground Chat", css=css, theme=theme) as demo:
         info="Higher values = more creative responses"
     )
 
-    def chat_submit(history, message, model_name, max_length, temperature):
-        return respond(history, message, model_name, max_length, temperature)
-
+    # Connect the interface components
+    submit_event = user_input.submit(
+        chat_submit,
+        inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
+        outputs=[user_input, chatbot, conversation_state]
+    )
+
     send_btn.click(
         chat_submit,
-        inputs=[state, user_input, model_choice, max_length, temperature],
-        outputs=[chatbot, user_input]
+        inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
+        outputs=[user_input, chatbot, conversation_state]
     )
 
+    # Add examples if desired
+    gr.Examples(
+        examples=[
+            "What is artificial intelligence?",
+            "Can you explain quantum computing?",
+            "Write a short poem about technology",
+            "What are some ethical concerns about AI?"
+        ],
+        inputs=[user_input]
+    )
+
+    gr.Markdown("""
+    ### About the Thinking Tags
+    Some Athena models (particularly R3X series) include reasoning in `<think></think>` tags.
+    Click "Show reasoning" to see the model's thought process behind its answers.
+    """)
+
 if __name__ == "__main__":
     demo.launch()
 
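Annotation: after this commit the app tracks two parallel histories: `conversation_state` keeps the raw assistant text (with `<think>` tags intact) to feed back to the model, while the `gr.Chatbot` component receives the HTML-formatted rendering. A rough single-turn sketch of that bookkeeping with the model call stubbed out (the stub name, model id, and sample strings are invented, not from the commit):

```python
def fake_generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
    # Stand-in for app.py's real generate_response, which loads the model and generates.
    return "4.<think>2 + 2 = 4, trivially.</think>", 0.0, 0.0

conversation_state = []  # raw messages, fed back to the model on the next turn
chat_history = []        # (user, formatted) pairs rendered by gr.Chatbot

message = "What is 2 + 2?"
response, load_time, generation_time = fake_generate_response("model-id-placeholder", conversation_state, message)
conversation_state.append({"role": "user", "content": message})
conversation_state.append({"role": "assistant", "content": response})  # <think> tags kept for the model
chat_history.append((message, response))  # app.py passes this through format_response_with_thinking first
```

Keeping the raw and rendered histories separate means the HTML toggle markup never leaks into the prompt sent back to the model.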