import spaces  # needed on ZeroGPU Spaces so GPU-backed functions can be decorated
import os
import uuid
import json
import atexit
import torch
from datetime import datetime, timedelta
from threading import Thread

# Gradio and HuggingFace imports
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from datasets import Dataset
from huggingface_hub import login
# Model configuration
checkpoint = "WillHeld/soft-raccoon"

# Set device based on availability
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    print("CUDA not available, using CPU instead. This will be much slower.")
# Dataset configuration
DATASET_NAME = "WillHeld/soft-raccoon-conversations"  # Change to your username
SAVE_INTERVAL_MINUTES = 5  # Save data every 5 minutes
last_save_time = datetime.now()
# Initialize model and tokenizer
print(f"Loading model from {checkpoint}...")
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
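# Note: for larger checkpoints, half-precision loading roughly halves GPU memory
# use. A minimal sketch (assuming a CUDA device; torch_dtype is a standard
# from_pretrained argument):
#   model = AutoModelForCausalLM.from_pretrained(
#       checkpoint, torch_dtype=torch.float16
#   ).to(device)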
# Data storage
conversations = []

# Hugging Face authentication
# Uncomment this line to login with your token
# login(token=os.environ.get("HF_TOKEN"))
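# On a Space, HF_TOKEN can be added as a repository secret (Settings -> Secrets);
# secrets are exposed to the app as environment variables, which is what the
# login call above expects.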
def save_to_dataset():
    """Save the current conversations to a HuggingFace dataset"""
    if not conversations:
        return None, f"No conversations to save. Last attempt: {datetime.now().strftime('%H:%M:%S')}"

    # Convert conversations to dataset format
    dataset_dict = {
        "conversation_id": [],
        "timestamp": [],
        "messages": [],
        "metadata": []
    }
    for conv in conversations:
        dataset_dict["conversation_id"].append(conv["conversation_id"])
        dataset_dict["timestamp"].append(conv["timestamp"])
        dataset_dict["messages"].append(json.dumps(conv["messages"]))
        dataset_dict["metadata"].append(json.dumps(conv["metadata"]))

    # Create dataset
    dataset = Dataset.from_dict(dataset_dict)
    try:
        # Push to hub
        dataset.push_to_hub(DATASET_NAME)
        status_msg = f"Successfully saved {len(conversations)} conversations to {DATASET_NAME}"
        print(status_msg)
    except Exception as e:
        # Save locally as fallback
        local_path = f"local_dataset_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        dataset.save_to_disk(local_path)
        status_msg = f"Error pushing to hub: {str(e)}. Saved locally to '{local_path}'"
        print(status_msg)
    return dataset, status_msg
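# Reading the pushed dataset back is a one-liner; a sketch (messages/metadata
# were stored as JSON strings above, so decode them on load):
#   from datasets import load_dataset
#   ds = load_dataset(DATASET_NAME, split="train")
#   first_conversation = json.loads(ds[0]["messages"])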
@spaces.GPU  # on ZeroGPU Spaces, CUDA work must run inside a GPU-decorated function
def chat_model(message, history, temperature=0.7, top_p=0.9):
    """Chat function for use with ChatInterface"""
    # NOTE: a function attribute is shared across all sessions, so concurrent
    # users will write into the same conversation record
    conversation_id = getattr(chat_model, "conversation_id", None)
    if conversation_id is None:
        conversation_id = str(uuid.uuid4())
        chat_model.conversation_id = conversation_id

    # Format tuple-style chat history (Gradio 4 ChatInterface default) for the model
    formatted_history = []
    for human_msg, ai_msg in history:
        formatted_history.append({"role": "user", "content": human_msg})
        if ai_msg:  # Skip None values that might occur during streaming
            formatted_history.append({"role": "assistant", "content": ai_msg})

    # Add the current message
    formatted_history.append({"role": "user", "content": message})

    # Prepare input for the model
    input_text = tokenizer.apply_chat_template(
        formatted_history,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    # Set up streaming
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Generation parameters (pass the attention mask explicitly to avoid
    # padding-related warnings)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 1024,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "do_sample": True,
        "streamer": streamer,
    }

    # generate() blocks until completion, so run it in a background thread
    # and consume tokens from the streamer on this one
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield partial text as it's generated
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    # Store conversation data in the global conversations list
    formatted_history.append({"role": "assistant", "content": partial_text})

    # Find existing conversation or create new one
    existing_conv = next((c for c in conversations if c["conversation_id"] == conversation_id), None)

    # Update or create conversation record
    current_time = datetime.now().isoformat()
    if existing_conv:
        # Update existing conversation
        existing_conv["messages"] = formatted_history
        existing_conv["metadata"]["last_updated"] = current_time
        existing_conv["metadata"]["temperature"] = temperature
        existing_conv["metadata"]["top_p"] = top_p
    else:
        # Create new conversation record
        conversations.append({
            "conversation_id": conversation_id,
            "timestamp": current_time,
            "messages": formatted_history,
            "metadata": {
                "model": checkpoint,
                "temperature": temperature,
                "top_p": top_p,
                "last_updated": current_time
            }
        })

    # Check if it's time to save based on elapsed time
    global last_save_time
    current_time_dt = datetime.now()
    if current_time_dt - last_save_time > timedelta(minutes=SAVE_INTERVAL_MINUTES):
        save_to_dataset()
        last_save_time = current_time_dt
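# Quick sanity check outside the UI, a sketch (drains the generator and prints
# the final completion; @spaces.GPU is a no-op when not running on ZeroGPU):
#   response = ""
#   for response in chat_model("Hello!", history=[]):
#       pass
#   print(response)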
def save_dataset_manually():
    """Manually trigger dataset save and return status"""
    _, status = save_to_dataset()
    return status
def get_stats():
    """Get current stats about conversations and saving"""
    # total_seconds() rather than .seconds, which wraps around after a day
    elapsed_mins = int((datetime.now() - last_save_time).total_seconds()) // 60
    mins_until_save = max(0, SAVE_INTERVAL_MINUTES - elapsed_mins)
    return {
        "conversation_count": len(conversations),
        "next_save": f"In {mins_until_save} minutes",
        "last_save": last_save_time.strftime('%H:%M:%S'),
        "dataset_name": DATASET_NAME
    }
# Create a Stanford theme
theme = gr.themes.Default(
    primary_hue=gr.themes.utils.colors.red,
    secondary_hue=gr.themes.utils.colors.gray,
    neutral_hue=gr.themes.utils.colors.gray,
    font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui"]
).set(
    button_primary_background_fill="#8C1515",
    button_primary_background_fill_hover="#771212",
    button_primary_text_color="white",
    slider_color="#8C1515",
    block_title_text_color="#8C1515",
    block_label_text_color="#4D4F53",
    input_border_color_focus="#8C1515",
    checkbox_background_color_selected="#8C1515",
    checkbox_border_color_selected="#8C1515",
    button_secondary_border_color="#4D4F53",
    block_title_background_fill="#f5f5f5",
    block_label_background_fill="#f9f9f9"
)
# Custom CSS
css = """
.gradio-container {
    font-family: 'Source Sans Pro', sans-serif !important;
}
.footer {
    color: #4D4F53 !important;
    font-size: 0.85em !important;
}
"""
# Set up the Gradio app with Blocks for more control
with gr.Blocks(theme=theme, title="Stanford Soft Raccoon Chat", css=css) as demo:
    with gr.Row():
        with gr.Column(scale=3):
            # Use ChatInterface for the main chat functionality
            chatbot = gr.ChatInterface(
                fn=chat_model,
                chatbot=gr.Chatbot(
                    label="Soft Raccoon Chat",
                    # avatar_images expects image paths/URLs; swap the emoji
                    # for an actual image file to get a rendered avatar
                    avatar_images=(None, "🌲"),  # Stanford tree emoji
                    height=600,
                    placeholder="<strong>Soft Raccoon AI Assistant</strong><br>Ask me anything!"
                ),
                additional_inputs=[
                    gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature"
                    ),
                    gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top-P"
                    )
                ],
                title="Stanford Soft Raccoon Chat",
                description="AI assistant powered by the Soft Raccoon language model",
                examples=[
                    "Tell me about Stanford University",
                    "How can I learn about artificial intelligence?",
                    "What's your favorite book?"
                ],
                cache_examples=True,
                # retry_btn/undo_btn/clear_btn were removed in Gradio 5;
                # drop these three arguments if running on a newer Gradio
                retry_btn="Regenerate",
                undo_btn="Undo",
                clear_btn="Clear",
            )
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Dataset Controls")
                save_button = gr.Button("Save conversations now", variant="secondary")
                status_output = gr.Textbox(label="Save Status", interactive=False)
                with gr.Row():
                    convo_count = gr.Number(label="Total Conversations", interactive=False)
                    next_save = gr.Textbox(label="Next Auto-Save", interactive=False)
                last_save_time_display = gr.Textbox(label="Last Save Time", interactive=False)
                dataset_name_display = gr.Textbox(label="Dataset Name", interactive=False)
                refresh_btn = gr.Button("Refresh Stats")
    # Set up event handlers
    save_button.click(
        save_dataset_manually,
        inputs=[],
        outputs=[status_output]
    )

    def update_stats():
        stats = get_stats()
        return [
            stats["conversation_count"],
            stats["next_save"],
            stats["last_save"],
            stats["dataset_name"]
        ]

    refresh_btn.click(
        update_stats,
        inputs=[],
        outputs=[convo_count, next_save, last_save_time_display, dataset_name_display]
    )
    # Auto-update stats every 30 seconds
    demo.load(
        update_stats,
        inputs=[],
        outputs=[convo_count, next_save, last_save_time_display, dataset_name_display],
        every=30  # Refresh every 30 seconds
    )
# Ensure we save on shutdown
atexit.register(save_to_dataset)
# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)
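# share=True opens a temporary public tunnel (and is ignored on Spaces, which
# handle hosting themselves); for a local-only deployment, a fixed host and
# port are the usual alternative (standard launch() options):
#   demo.launch(server_name="0.0.0.0", server_port=7860)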