WillHeld commited on
Commit
6a10fd5
·
verified ·
1 Parent(s): 752950e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import uuid
3
  import time
@@ -14,7 +15,12 @@ from huggingface_hub import HfApi, login
14
 
15
  # Model configuration
16
  checkpoint = "WillHeld/soft-raccoon"
17
- device = "cuda"
 
 
 
 
 
18
 
19
  # Dataset configuration
20
  DATASET_NAME = "your-username/soft-raccoon-conversations" # Change to your username
@@ -71,6 +77,7 @@ def save_to_dataset():
71
  return dataset, status_msg
72
 
73
 
 
74
  def predict(message, chat_history, temperature, top_p, conversation_id=None):
75
  """Generate a response using the model and save the conversation"""
76
  # Create/retrieve conversation ID for tracking
@@ -203,8 +210,8 @@ with gr.Blocks(theme=theme, title="Stanford Soft Raccoon Chat with Dataset Colle
203
  with gr.Row():
204
  with gr.Column(scale=3):
205
  chatbot = gr.Chatbot(
206
- label="Soft Raccoon Chat",
207
- avatar_images=(None, "🦝"),
208
  height=600
209
  )
210
 
@@ -303,5 +310,4 @@ with gr.Blocks(theme=theme, title="Stanford Soft Raccoon Chat with Dataset Colle
303
 
304
  # Launch the app
305
  if __name__ == "__main__":
306
- demo.launch(share=True)
307
-
 
1
+ import torchimport spaces
2
  import os
3
  import uuid
4
  import time
 
15
 
16
  # Model configuration
17
  checkpoint = "WillHeld/soft-raccoon"
18
+ # Set device based on availability
19
+ if torch.cuda.is_available():
20
+ device = "cuda"
21
+ else:
22
+ device = "cpu"
23
+ print("CUDA not available, using CPU instead. This will be much slower.")
24
 
25
  # Dataset configuration
26
  DATASET_NAME = "your-username/soft-raccoon-conversations" # Change to your username
 
77
  return dataset, status_msg
78
 
79
 
80
+ @spaces.GPU(duration=120)
81
  def predict(message, chat_history, temperature, top_p, conversation_id=None):
82
  """Generate a response using the model and save the conversation"""
83
  # Create/retrieve conversation ID for tracking
 
210
  with gr.Row():
211
  with gr.Column(scale=3):
212
  chatbot = gr.Chatbot(
213
+ label="Stanford Soft Raccoon Chat",
214
+ avatar_images=(None, "🌲"), # Stanford tree emoji
215
  height=600
216
  )
217
 
 
310
 
311
  # Launch the app
312
  if __name__ == "__main__":
313
+ demo.launch(share=True)