import os
import uuid
import json
import atexit
from datetime import datetime, timedelta
from threading import Thread

# Gradio and Hugging Face imports
import gradio as gr
from gradio.themes import Base
from gradio.themes.utils import colors
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from datasets import Dataset
from huggingface_hub import login

# Model configuration
checkpoint = "WillHeld/soft-raccoon"
device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU so the app still starts without a GPU

# Dataset configuration
DATASET_NAME = "your-username/soft-raccoon-conversations"  # Change to your username
SAVE_INTERVAL_MINUTES = 5  # Save data every 5 minutes
last_save_time = datetime.now()

# Initialize model and tokenizer
print(f"Loading model from {checkpoint}...")
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

# Data storage
conversations = []

# Hugging Face authentication
# Uncomment this line to log in with your token
# login(token=os.environ.get("HF_TOKEN"))
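# Deployment note (an assumption, not part of the original code): push_to_hub
# below needs a token with write access. On a Space, the usual pattern is to
# store it as an HF_TOKEN secret and log in only when it is present:
#
#   if os.environ.get("HF_TOKEN"):
#       login(token=os.environ["HF_TOKEN"])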
def save_to_dataset():
    """Save the current conversations to a Hugging Face dataset."""
    if not conversations:
        return None, f"No conversations to save. Last attempt: {datetime.now().strftime('%H:%M:%S')}"

    # Convert conversations to dataset format
    dataset_dict = {
        "conversation_id": [],
        "timestamp": [],
        "messages": [],
        "metadata": []
    }
    for conv in conversations:
        dataset_dict["conversation_id"].append(conv["conversation_id"])
        dataset_dict["timestamp"].append(conv["timestamp"])
        dataset_dict["messages"].append(json.dumps(conv["messages"]))
        dataset_dict["metadata"].append(json.dumps(conv["metadata"]))

    # Create dataset
    dataset = Dataset.from_dict(dataset_dict)
    try:
        # Push to hub
        dataset.push_to_hub(DATASET_NAME)
        status_msg = f"Successfully saved {len(conversations)} conversations to {DATASET_NAME}"
        print(status_msg)
    except Exception as e:
        # Save locally as a fallback
        local_path = f"local_dataset_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        dataset.save_to_disk(local_path)
        status_msg = f"Error pushing to hub: {str(e)}. Saved locally to '{local_path}'"
        print(status_msg)
    return dataset, status_msg
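# Usage sketch (assumes the push above succeeded and the repo is readable):
#
#   from datasets import load_dataset
#   ds = load_dataset(DATASET_NAME, split="train")
#   messages = json.loads(ds[0]["messages"])  # messages are stored as JSON strings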
def predict(message, chat_history, temperature, top_p, conversation_id=None):
    """Generate a response using the model and save the conversation."""
    # Create/retrieve a conversation ID for tracking
    if conversation_id is None or conversation_id == "":
        conversation_id = str(uuid.uuid4())

    # Format chat history for the model
    formatted_history = []
    for human_msg, ai_msg in chat_history:
        formatted_history.append({"role": "user", "content": human_msg})
        if ai_msg:  # Skip None values that might occur during streaming
            formatted_history.append({"role": "assistant", "content": ai_msg})

    # Add the current message
    formatted_history.append({"role": "user", "content": message})

    # Prepare input for the model
    input_text = tokenizer.apply_chat_template(
        formatted_history,
        tokenize=False,
        add_generation_prompt=True
    )
    # apply_chat_template already inserts special tokens, so don't add them twice
    inputs = tokenizer.encode(input_text, return_tensors="pt", add_special_tokens=False).to(device)

    # Set up streaming
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Generation parameters
    generation_kwargs = {
        "input_ids": inputs,
        "max_new_tokens": 1024,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "do_sample": True,
        "streamer": streamer,
    }

    # Generate in a separate thread so we can stream tokens as they arrive
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield partial text as it is generated
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield chat_history + [[message, partial_text]], conversation_id
    thread.join()  # make sure generation has finished before bookkeeping

    # Update history with the final response
    formatted_history.append({"role": "assistant", "content": partial_text})

    # Update or create the conversation record
    current_time = datetime.now().isoformat()
    existing_conv = next((c for c in conversations if c["conversation_id"] == conversation_id), None)
    if existing_conv:
        # Update existing conversation
        existing_conv["messages"] = formatted_history
        existing_conv["metadata"]["last_updated"] = current_time
        existing_conv["metadata"]["temperature"] = temperature
        existing_conv["metadata"]["top_p"] = top_p
    else:
        # Create new conversation record
        conversations.append({
            "conversation_id": conversation_id,
            "timestamp": current_time,
            "messages": formatted_history,
            "metadata": {
                "model": checkpoint,
                "temperature": temperature,
                "top_p": top_p,
                "last_updated": current_time
            }
        })

    # Save periodically based on elapsed time
    global last_save_time
    current_time_dt = datetime.now()
    if current_time_dt - last_save_time > timedelta(minutes=SAVE_INTERVAL_MINUTES):
        save_to_dataset()
        last_save_time = current_time_dt

    # Final yield: Gradio uses the last yielded value from a generator; a
    # `return` with a value here would be silently discarded.
    yield chat_history + [[message, partial_text]], conversation_id
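# Since predict is a generator, callers iterate it to stream updates. A
# hypothetical direct call outside Gradio (argument values chosen arbitrarily):
#
#   for history, conv_id in predict("Hello!", [], temperature=0.7, top_p=0.9):
#       print(history[-1][1])  # the growing assistant reply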
def save_dataset_manually():
    """Manually trigger a dataset save."""
    _, status = save_to_dataset()
    return status

def get_stats():
    """Get current stats about conversations and saving."""
    mins_until_save = max(0, SAVE_INTERVAL_MINUTES - (datetime.now() - last_save_time).seconds // 60)
    return {
        "conversation_count": len(conversations),
        "next_save": f"In {mins_until_save} minutes",
        "last_save": last_save_time.strftime('%H:%M:%S'),
        "dataset_name": DATASET_NAME
    }
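# Illustrative return value (example numbers, not from a real run):
#   {"conversation_count": 12, "next_save": "In 3 minutes",
#    "last_save": "14:05:12", "dataset_name": "your-username/soft-raccoon-conversations"}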
# Create a custom Stanford theme
# Hue palettes must be gradio.themes.utils.colors.Color objects rather than
# plain dicts, and custom themes subclass gradio.themes.Base.
cardinal = colors.Color(
    name="cardinal",
    c50="#F9E8E8", c100="#F0C9C9", c200="#E39B9B", c300="#D66E6E",
    c400="#C94A4A", c500="#B82C2C", c600="#8C1515", c700="#771212",
    c800="#620E0E", c900="#4D0A0A", c950="#380707",
)
cool_gray = colors.Color(
    name="cool_gray",
    c50="#F5F5F6", c100="#E6E7E8", c200="#CDCED0", c300="#B3B5B8",
    c400="#9A9CA0", c500="#818388", c600="#4D4F53", c700="#424448",
    c800="#36383A", c900="#2E2D29", c950="#1D1D1B",
)

class StanfordTheme(Base):
    def __init__(self):
        super().__init__(
            primary_hue=cardinal,
            secondary_hue=cool_gray,
            neutral_hue="gray",
            radius_size=gr.themes.sizes.radius_sm,
            font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui"]
        )

# Use the Stanford theme
theme = StanfordTheme()
# Set up the Gradio app
with gr.Blocks(theme=theme, title="Stanford Soft Raccoon Chat with Dataset Collection") as demo:
    conversation_id = gr.State("")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Soft Raccoon Chat",
                # avatar_images expects image paths/URLs; a bare emoji may not
                # render in every Gradio version
                avatar_images=(None, "🦝"),
                height=600
            )
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Send a message...",
                    show_label=False,
                    container=False
                )
                submit_btn = gr.Button("Send", variant="primary")
            with gr.Accordion("Generation Parameters", open=False):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top-P"
                )
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Dataset Controls")
                save_button = gr.Button("Save conversations now", variant="secondary")
                status_output = gr.Textbox(label="Save Status", interactive=False)
                with gr.Row():
                    convo_count = gr.Number(label="Total Conversations", interactive=False)
                    next_save = gr.Textbox(label="Next Auto-Save", interactive=False)
                last_save_time_display = gr.Textbox(label="Last Save Time", interactive=False)
                dataset_name_display = gr.Textbox(label="Dataset Name", interactive=False)
                refresh_btn = gr.Button("Refresh Stats")
    # Set up event handlers
    submit_btn.click(
        predict,
        [msg, chatbot, temperature, top_p, conversation_id],
        [chatbot, conversation_id],
        api_name="chat"
    )
    msg.submit(
        predict,
        [msg, chatbot, temperature, top_p, conversation_id],
        [chatbot, conversation_id],
        api_name=False
    )
    save_button.click(
        save_dataset_manually,
        [],
        [status_output]
    )

    def update_stats():
        stats = get_stats()
        return [
            stats["conversation_count"],
            stats["next_save"],
            stats["last_save"],
            stats["dataset_name"]
        ]

    refresh_btn.click(
        update_stats,
        [],
        [convo_count, next_save, last_save_time_display, dataset_name_display]
    )
    # Auto-update stats on page load and every 30 seconds. gr.Timeout is not a
    # Gradio API; gr.Timer (available in recent Gradio releases) provides the
    # periodic trigger, and demo.load covers the initial page load.
    stats_timer = gr.Timer(30)
    gr.on(
        [demo.load, stats_timer.tick],
        update_stats,
        [],
        [convo_count, next_save, last_save_time_display, dataset_name_display]
    )

# Ensure we save on shutdown (best effort: atexit does not run if the process
# is killed hard, but the periodic auto-save above limits any loss)
atexit.register(save_to_dataset)
# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)  # share=True matters only for local runs; Spaces ignores it
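# Environment assumption (not from the original file): the Space's
# requirements.txt needs roughly
#   gradio>=4.38   # gr.Timer requires a recent Gradio release
#   torch
#   transformers
#   datasets
#   huggingface_hub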