WillHeld committed on
Commit
752950e
·
verified ·
1 Parent(s): b0dd995

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -96
app.py CHANGED
@@ -1,38 +1,43 @@
1
- import spaces
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
3
- import gradio as gr
4
- from threading import Thread
5
  from datetime import datetime, timedelta
 
 
 
 
 
 
6
  from datasets import Dataset
7
  from huggingface_hub import HfApi, login
8
- import uuid
9
- import os
10
- import time
11
 
 
12
  checkpoint = "WillHeld/soft-raccoon"
13
  device = "cuda"
14
- tokenizer = AutoTokenizer.from_pretrained(checkpoint)
15
- model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
16
 
17
  # Dataset configuration
18
- DATASET_NAME = "WillHeld/soft-raccoon-conversations" # Change to your HF username
19
- PUSH_TO_HUB = True # Set to False if you just want to save locally first
20
-
21
- # Time-based storage settings
22
- SAVE_INTERVAL_MINUTES = 5 # Save every 5 minutes
23
  last_save_time = datetime.now()
24
 
25
- # Initialize storage for conversations
 
 
 
 
 
26
  conversations = []
27
 
28
- # Login to Huggingface Hub (you'll need to set HF_TOKEN env var or use login())
29
- # Uncomment the below line to login with your token
30
- login(token=os.environ.get("HF_TOKEN"))
 
31
 
32
  def save_to_dataset():
33
  """Save the current conversations to a HuggingFace dataset"""
34
  if not conversations:
35
- return None
36
 
37
  # Convert conversations to dataset format
38
  dataset_dict = {
@@ -45,43 +50,55 @@ def save_to_dataset():
45
  for conv in conversations:
46
  dataset_dict["conversation_id"].append(conv["conversation_id"])
47
  dataset_dict["timestamp"].append(conv["timestamp"])
48
- dataset_dict["messages"].append(conv["messages"])
49
- dataset_dict["metadata"].append(conv["metadata"])
50
 
51
  # Create dataset
52
  dataset = Dataset.from_dict(dataset_dict)
53
 
54
- if PUSH_TO_HUB:
55
- try:
56
- # Push to hub - will create the dataset if it doesn't exist
57
- dataset.push_to_hub(DATASET_NAME)
58
- print(f"Successfully pushed {len(conversations)} conversations to {DATASET_NAME}")
59
- except Exception as e:
60
- print(f"Error pushing to hub: {e}")
61
- # Save locally as fallback
62
- dataset.save_to_disk("local_dataset")
63
- else:
64
- # Save locally
65
- dataset.save_to_disk("local_dataset")
66
- print(f"Saved {len(conversations)} conversations locally to 'local_dataset'")
67
 
68
- return dataset
 
69
 
70
- @spaces.GPU(duration=120)
71
- def predict(message, history, temperature, top_p, conversation_id=None):
72
- # Create or retrieve conversation ID for tracking
73
- if conversation_id is None:
74
  conversation_id = str(uuid.uuid4())
75
 
76
- # Update history with user message
77
- history.append({"role": "user", "content": message})
78
- input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
80
 
81
- # Create a streamer
82
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
83
 
84
- # Set up generation parameters
85
  generation_kwargs = {
86
  "input_ids": inputs,
87
  "max_new_tokens": 1024,
@@ -91,102 +108,200 @@ def predict(message, history, temperature, top_p, conversation_id=None):
91
  "streamer": streamer,
92
  }
93
 
94
- # Run generation in a separate thread
95
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
96
  thread.start()
97
 
98
- # Yield from the streamer as tokens are generated
99
  partial_text = ""
 
 
100
  for new_text in streamer:
101
  partial_text += new_text
102
- yield partial_text
103
-
104
- # After generation completes, update history with assistant response
105
- history.append({"role": "assistant", "content": partial_text})
106
 
107
  # Store conversation data
108
- # Check if we already have this conversation
109
  existing_conv = next((c for c in conversations if c["conversation_id"] == conversation_id), None)
110
 
 
 
 
 
 
111
  if existing_conv:
112
  # Update existing conversation
113
- existing_conv["messages"] = history
114
- existing_conv["metadata"]["last_updated"] = datetime.now().isoformat()
 
 
115
  else:
116
  # Create new conversation record
117
  conversations.append({
118
  "conversation_id": conversation_id,
119
- "timestamp": datetime.now().isoformat(),
120
- "messages": history,
121
  "metadata": {
122
  "model": checkpoint,
123
  "temperature": temperature,
124
  "top_p": top_p,
125
- "last_updated": datetime.now().isoformat()
126
  }
127
  })
128
 
129
  # Check if it's time to save based on elapsed time
130
  global last_save_time
131
- current_time = datetime.now()
132
- if current_time - last_save_time > timedelta(minutes=SAVE_INTERVAL_MINUTES):
133
  save_to_dataset()
134
- last_save_time = current_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- return partial_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
- def save_dataset_button():
139
- """Manually save the current dataset"""
140
- dataset = save_to_dataset()
141
- if dataset:
142
- return f"Saved {len(conversations)} conversations to dataset."
143
- return "No conversations to save."
144
 
145
- with gr.Blocks() as demo:
146
- conversation_id = gr.State(None)
 
147
 
148
  with gr.Row():
149
  with gr.Column(scale=3):
150
- chatbot = gr.ChatInterface(
151
- predict,
152
- additional_inputs=[
153
- gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
154
- gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P"),
155
- conversation_id
156
- ],
157
- type="messages"
158
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
  with gr.Column(scale=1):
161
  with gr.Group():
162
  gr.Markdown("### Dataset Controls")
163
- save_button = gr.Button("Save conversations to dataset")
164
- save_output = gr.Textbox(label="Save Status")
165
 
166
- # Display current conversation count
167
- conversation_count = gr.Number(value=lambda: len(conversations),
168
- label="Total Conversations",
169
- interactive=False)
170
 
171
- # Display time until next auto-save
172
- next_save_time = gr.Textbox(label="Next Auto-Save",
173
- value=lambda: f"In {SAVE_INTERVAL_MINUTES - (datetime.now() - last_save_time).seconds // 60} minutes")
174
- refresh_button = gr.Button("Refresh Stats")
175
 
176
  # Set up event handlers
177
- save_button.click(save_dataset_button, outputs=save_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
- def refresh_stats():
180
- mins_until_save = SAVE_INTERVAL_MINUTES - (datetime.now() - last_save_time).seconds // 60
181
- return len(conversations), f"In {mins_until_save} minutes"
 
 
 
 
182
 
183
- refresh_button.click(refresh_stats, outputs=[conversation_count, next_save_time])
 
 
184
 
185
- # Save on shutdown
186
- demo.on_close(save_to_dataset)
 
187
 
188
- # Set up periodic UI refresh (every 60 seconds)
189
- gr.Timer(60, lambda: None).start()
190
 
 
191
  if __name__ == "__main__":
192
- demo.launch()
 
 
1
+ import os
2
+ import uuid
3
+ import time
4
+ import json
5
  from datetime import datetime, timedelta
6
+ from threading import Thread
7
+
8
+ # Gradio and HuggingFace imports
9
+ import gradio as gr
10
+ from gradio.themes import Base
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
12
  from datasets import Dataset
13
  from huggingface_hub import HfApi, login
 
 
 
14
 
15
# Model configuration
checkpoint = "WillHeld/soft-raccoon"
device = "cuda"  # NOTE(review): assumes a CUDA device is present — confirm for CPU-only hosts

# Dataset configuration
DATASET_NAME = "your-username/soft-raccoon-conversations"  # Change to your username
SAVE_INTERVAL_MINUTES = 5  # Save data every 5 minutes
# Wall-clock time of the last successful/attempted autosave; compared in predict().
last_save_time = datetime.now()

# Initialize model and tokenizer
print(f"Loading model from {checkpoint}...")
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

# Data storage
# In-memory list of conversation records; lost on process restart unless
# flushed via save_to_dataset().
conversations = []

# Hugging Face authentication
# Uncomment this line to login with your token
# login(token=os.environ.get("HF_TOKEN"))
35
+
36
 
37
def save_to_dataset():
    """Persist the in-memory conversation log to the HuggingFace Hub.

    Returns:
        A ``(dataset, status_message)`` tuple. ``dataset`` is ``None`` when
        there is nothing to save. If the Hub push fails, the dataset is
        written to a timestamped local directory as a fallback.
    """
    if not conversations:
        return None, f"No conversations to save. Last attempt: {datetime.now().strftime('%H:%M:%S')}"

    # Flatten each conversation record into parallel columns.
    # Messages and metadata are JSON-encoded so the column stays a plain string.
    dataset_dict = {
        "conversation_id": [c["conversation_id"] for c in conversations],
        "timestamp": [c["timestamp"] for c in conversations],
        "messages": [json.dumps(c["messages"]) for c in conversations],
        "metadata": [json.dumps(c["metadata"]) for c in conversations],
    }

    dataset = Dataset.from_dict(dataset_dict)

    try:
        # Push to hub
        dataset.push_to_hub(DATASET_NAME)
        status_msg = f"Successfully saved {len(conversations)} conversations to {DATASET_NAME}"
    except Exception as e:
        # Save locally as fallback when the Hub is unreachable/unauthenticated.
        local_path = f"local_dataset_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        dataset.save_to_disk(local_path)
        status_msg = f"Error pushing to hub: {str(e)}. Saved locally to '{local_path}'"
    print(status_msg)

    return dataset, status_msg
72
+
73
 
74
+ def predict(message, chat_history, temperature, top_p, conversation_id=None):
75
+ """Generate a response using the model and save the conversation"""
76
+ # Create/retrieve conversation ID for tracking
77
+ if conversation_id is None or conversation_id == "":
78
  conversation_id = str(uuid.uuid4())
79
 
80
+ # Format chat history for the model
81
+ formatted_history = []
82
+ for human_msg, ai_msg in chat_history:
83
+ formatted_history.append({"role": "user", "content": human_msg})
84
+ if ai_msg: # Skip None values that might occur during streaming
85
+ formatted_history.append({"role": "assistant", "content": ai_msg})
86
+
87
+ # Add the current message
88
+ formatted_history.append({"role": "user", "content": message})
89
+
90
+ # Prepare input for the model
91
+ input_text = tokenizer.apply_chat_template(
92
+ formatted_history,
93
+ tokenize=False,
94
+ add_generation_prompt=True
95
+ )
96
  inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
97
 
98
+ # Set up streaming
99
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
100
 
101
+ # Generation parameters
102
  generation_kwargs = {
103
  "input_ids": inputs,
104
  "max_new_tokens": 1024,
 
108
  "streamer": streamer,
109
  }
110
 
111
+ # Generate in a separate thread
112
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
113
  thread.start()
114
 
115
+ # Initialize response
116
  partial_text = ""
117
+
118
+ # Yield partial text as it's generated
119
  for new_text in streamer:
120
  partial_text += new_text
121
+ yield chat_history + [[message, partial_text]], conversation_id
 
 
 
122
 
123
  # Store conversation data
 
124
  existing_conv = next((c for c in conversations if c["conversation_id"] == conversation_id), None)
125
 
126
+ # Update history with final response
127
+ formatted_history.append({"role": "assistant", "content": partial_text})
128
+
129
+ # Update or create conversation record
130
+ current_time = datetime.now().isoformat()
131
  if existing_conv:
132
  # Update existing conversation
133
+ existing_conv["messages"] = formatted_history
134
+ existing_conv["metadata"]["last_updated"] = current_time
135
+ existing_conv["metadata"]["temperature"] = temperature
136
+ existing_conv["metadata"]["top_p"] = top_p
137
  else:
138
  # Create new conversation record
139
  conversations.append({
140
  "conversation_id": conversation_id,
141
+ "timestamp": current_time,
142
+ "messages": formatted_history,
143
  "metadata": {
144
  "model": checkpoint,
145
  "temperature": temperature,
146
  "top_p": top_p,
147
+ "last_updated": current_time
148
  }
149
  })
150
 
151
  # Check if it's time to save based on elapsed time
152
  global last_save_time
153
+ current_time_dt = datetime.now()
154
+ if current_time_dt - last_save_time > timedelta(minutes=SAVE_INTERVAL_MINUTES):
155
  save_to_dataset()
156
+ last_save_time = current_time_dt
157
+
158
+ return chat_history + [[message, partial_text]], conversation_id
159
+
160
+
161
def save_dataset_manually():
    """Button callback: force an immediate save and return the status text."""
    result = save_to_dataset()
    # save_to_dataset() returns (dataset, status); only the status is shown.
    return result[1]
165
+
166
+
167
def get_stats():
    """Get current stats about conversations and saving.

    Returns:
        dict with keys ``conversation_count`` (int), ``next_save`` (human
        string), ``last_save`` (HH:MM:SS string) and ``dataset_name``.
    """
    # BUG FIX: the original used timedelta.seconds, which wraps at 24 hours
    # (it ignores the .days component), so after a day without a save the
    # countdown became nonsense. total_seconds() measures the true elapsed time.
    elapsed_minutes = int((datetime.now() - last_save_time).total_seconds()) // 60
    mins_until_save = max(0, SAVE_INTERVAL_MINUTES - elapsed_minutes)

    return {
        "conversation_count": len(conversations),
        "next_save": f"In {mins_until_save} minutes",
        "last_save": last_save_time.strftime('%H:%M:%S'),
        "dataset_name": DATASET_NAME
    }
179
+
180
+
181
# Create a custom Stanford theme
# NOTE(review): gr.Theme may not be exported at the top level in all Gradio
# versions — gradio.themes.Base (already imported above) is the documented
# base class; confirm against the installed Gradio version.
class StanfordTheme(gr.Theme):
    """Gradio theme using Stanford brand colors.

    The primary hue ramps around Stanford Cardinal red (#8C1515 at c600);
    the secondary hue ramps around the brand cool gray (#4D4F53 at c600).
    """

    def __init__(self):
        super().__init__(
            primary_hue={"name": "cardinal", "c50": "#F9E8E8", "c100": "#F0C9C9", "c200": "#E39B9B",
                         "c300": "#D66E6E", "c400": "#C94A4A", "c500": "#B82C2C", "c600": "#8C1515",
                         "c700": "#771212", "c800": "#620E0E", "c900": "#4D0A0A", "c950": "#380707"},
            secondary_hue={"name": "cool_gray", "c50": "#F5F5F6", "c100": "#E6E7E8", "c200": "#CDCED0",
                           "c300": "#B3B5B8", "c400": "#9A9CA0", "c500": "#818388", "c600": "#4D4F53",
                           "c700": "#424448", "c800": "#36383A", "c900": "#2E2D29", "c950": "#1D1D1B"},
            neutral_hue="gray",
            radius_size=gr.themes.sizes.radius_sm,
            font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui"]
        )

# Use the Stanford theme
theme = StanfordTheme()
 
 
 
 
198
 
199
# Set up the Gradio app: chat panel on the left, dataset controls on the right.
with gr.Blocks(theme=theme, title="Stanford Soft Raccoon Chat with Dataset Collection") as demo:
    # Per-session conversation id; filled in by predict() on first message.
    conversation_id = gr.State("")

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Soft Raccoon Chat",
                # NOTE(review): avatar_images expects image filepaths/URLs in
                # most Gradio versions — confirm the emoji renders as intended.
                avatar_images=(None, "🦝"),
                height=600
            )

            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Send a message...",
                    show_label=False,
                    container=False
                )
                submit_btn = gr.Button("Send", variant="primary")

            with gr.Accordion("Generation Parameters", open=False):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top-P"
                )

        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Dataset Controls")
                save_button = gr.Button("Save conversations now", variant="secondary")
                status_output = gr.Textbox(label="Save Status", interactive=False)

                with gr.Row():
                    convo_count = gr.Number(label="Total Conversations", interactive=False)
                    next_save = gr.Textbox(label="Next Auto-Save", interactive=False)

                last_save_time_display = gr.Textbox(label="Last Save Time", interactive=False)
                dataset_name_display = gr.Textbox(label="Dataset Name", interactive=False)

                refresh_btn = gr.Button("Refresh Stats")

    # Set up event handlers
    submit_btn.click(
        predict,
        [msg, chatbot, temperature, top_p, conversation_id],
        [chatbot, conversation_id],
        api_name="chat"
    )

    msg.submit(
        predict,
        [msg, chatbot, temperature, top_p, conversation_id],
        [chatbot, conversation_id],
        api_name=False
    )

    save_button.click(
        save_dataset_manually,
        [],
        [status_output]
    )

    stats_outputs = [convo_count, next_save, last_save_time_display, dataset_name_display]

    def update_stats():
        """Map get_stats() onto the four stats widgets, in output order."""
        stats = get_stats()
        return [
            stats["conversation_count"],
            stats["next_save"],
            stats["last_save"],
            stats["dataset_name"]
        ]

    refresh_btn.click(
        update_stats,
        [],
        stats_outputs
    )

    # BUG FIX: the original called gr.on([demo.load, gr.Timeout(30)], ...),
    # but gr.Timeout does not exist in Gradio, which raises AttributeError at
    # import time; it also wired the load-time refresh twice (gr.on with
    # demo.load AND a separate demo.load(on_startup)). Use a Timer for the
    # 30-second periodic refresh and a single load handler instead.
    # NOTE(review): gr.Timer/.tick is the Gradio 4.x API — confirm version.
    stats_timer = gr.Timer(30)
    stats_timer.tick(update_stats, [], stats_outputs)

    # Populate the stats widgets once when the page loads.
    demo.load(update_stats, [], stats_outputs)

    # Ensure we save on shutdown using atexit
    import atexit
    atexit.register(save_to_dataset)
 
303
 
304
# Launch the app
if __name__ == "__main__":
    # share=True creates a public Gradio tunnel link; it is redundant (and
    # logged as a warning) when the app is hosted on HuggingFace Spaces.
    demo.launch(share=True)
307
+