# HuggingFace Spaces app script (non-code page-header text removed so the file parses).
import spaces | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
import gradio as gr | |
from threading import Thread | |
from datetime import datetime, timedelta | |
from datasets import Dataset | |
from huggingface_hub import HfApi, login | |
import uuid | |
import os | |
import time | |
# --- Model setup ---
checkpoint = "WillHeld/soft-raccoon"
device = "cuda"  # NOTE(review): assumes a CUDA GPU is present — no CPU fallback here
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

# Dataset configuration
DATASET_NAME = "WillHeld/soft-raccoon-conversations"  # Change to your HF username
PUSH_TO_HUB = True  # Set to False if you just want to save locally first

# Time-based storage settings
SAVE_INTERVAL_MINUTES = 5  # Save every 5 minutes
last_save_time = datetime.now()  # rewritten by predict() after each auto-save

# Initialize storage for conversations.
# In-memory only: contents are lost on process restart unless save_to_dataset() ran.
conversations = []

# Login to Huggingface Hub (requires the HF_TOKEN env var to be set;
# login(token=None) is a no-op prompt in non-interactive environments)
login(token=os.environ.get("HF_TOKEN"))
def save_to_dataset():
    """Persist all in-memory conversations as a HuggingFace dataset.

    Builds a column-oriented ``Dataset`` from the module-level ``conversations``
    list. When ``PUSH_TO_HUB`` is set, pushes it to ``DATASET_NAME`` (falling
    back to a local save on failure); otherwise saves to ``local_dataset``.

    Returns:
        The constructed ``Dataset``, or ``None`` when there is nothing to save.
    """
    if not conversations:
        return None

    # Transpose the row-wise conversation records into the column-wise
    # mapping that Dataset.from_dict expects.
    columns = ("conversation_id", "timestamp", "messages", "metadata")
    dataset = Dataset.from_dict(
        {column: [record[column] for record in conversations] for column in columns}
    )

    if PUSH_TO_HUB:
        try:
            # Push to hub - will create the dataset if it doesn't exist
            dataset.push_to_hub(DATASET_NAME)
            print(f"Successfully pushed {len(conversations)} conversations to {DATASET_NAME}")
        except Exception as e:
            print(f"Error pushing to hub: {e}")
            # Save locally as fallback
            dataset.save_to_disk("local_dataset")
    else:
        # Save locally
        dataset.save_to_disk("local_dataset")
        print(f"Saved {len(conversations)} conversations locally to 'local_dataset'")

    return dataset
def predict(message, history, temperature, top_p, conversation_id=None):
    """Stream a chat completion and record the exchange.

    Generator handler for gr.ChatInterface: yields the progressively longer
    assistant reply while a background thread runs model.generate, then logs
    the finished turn into the module-level `conversations` list and triggers
    a time-based auto-save.

    Args:
        message: latest user message.
        history: chat history as a list of {"role", "content"} dicts
            (type="messages" format).
        temperature: sampling temperature from the UI slider.
        top_p: nucleus-sampling cutoff from the UI slider.
        conversation_id: id grouping turns of one conversation.
            NOTE(review): the gr.State feeding this is never written back by
            any event, so it is always None here and every call mints a fresh
            UUID / conversation record — confirm whether per-session grouping
            was intended.
    """
    # Create or retrieve conversation ID for tracking
    if conversation_id is None:
        conversation_id = str(uuid.uuid4())

    # Update history with user message
    history.append({"role": "user", "content": message})
    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Create a streamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Set up generation parameters
    generation_kwargs = {
        "input_ids": inputs,
        "max_new_tokens": 1024,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "do_sample": True,
        "streamer": streamer,
    }

    # Run generation in a separate thread; the streamer below ends its
    # iteration when generate() finishes, so the thread is not joined.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield from the streamer as tokens are generated
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    # After generation completes, update history with assistant response
    history.append({"role": "assistant", "content": partial_text})

    # Store conversation data
    # Check if we already have this conversation
    existing_conv = next((c for c in conversations if c["conversation_id"] == conversation_id), None)

    if existing_conv:
        # Update existing conversation
        existing_conv["messages"] = history
        existing_conv["metadata"]["last_updated"] = datetime.now().isoformat()
    else:
        # Create new conversation record
        conversations.append({
            "conversation_id": conversation_id,
            "timestamp": datetime.now().isoformat(),
            "messages": history,
            "metadata": {
                "model": checkpoint,
                "temperature": temperature,
                "top_p": top_p,
                "last_updated": datetime.now().isoformat()
            }
        })

    # Check if it's time to save based on elapsed time
    global last_save_time
    current_time = datetime.now()
    if current_time - last_save_time > timedelta(minutes=SAVE_INTERVAL_MINUTES):
        save_to_dataset()
        last_save_time = current_time

    # Return value of a generator is only the StopIteration payload;
    # Gradio uses the yielded values above.
    return partial_text
def save_dataset_button():
    """Button handler: run a manual save and report the outcome as text."""
    saved = save_to_dataset()
    if saved is None:
        return "No conversations to save."
    return f"Saved {len(conversations)} conversations to dataset."
# --- UI ---
with gr.Blocks() as demo:
    # Per-session conversation id. NOTE(review): no event writes this State
    # back, so predict() always receives None — see note there.
    conversation_id = gr.State(None)

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.ChatInterface(
                predict,
                additional_inputs=[
                    gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P"),
                    conversation_id,
                ],
                type="messages",
            )

        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Dataset Controls")
                save_button = gr.Button("Save conversations to dataset")
                save_output = gr.Textbox(label="Save Status")

                # Display current conversation count
                conversation_count = gr.Number(
                    value=lambda: len(conversations),
                    label="Total Conversations",
                    interactive=False,
                )

                # Display time until next auto-save
                next_save_time = gr.Textbox(
                    label="Next Auto-Save",
                    value=lambda: f"In {SAVE_INTERVAL_MINUTES - (datetime.now() - last_save_time).seconds // 60} minutes",
                )

                refresh_button = gr.Button("Refresh Stats")

    # Set up event handlers
    save_button.click(save_dataset_button, outputs=save_output)

    def refresh_stats():
        """Recompute the stats panel: conversation count and minutes to auto-save."""
        mins_until_save = SAVE_INTERVAL_MINUTES - (datetime.now() - last_save_time).seconds // 60
        return len(conversations), f"In {mins_until_save} minutes"

    refresh_button.click(refresh_stats, outputs=[conversation_count, next_save_time])

    # BUG FIX: gr.Blocks has no `on_close` method — the original
    # `demo.on_close(save_to_dataset)` raised AttributeError at startup.
    # The supported hook is the `unload` event, fired when a session ends.
    demo.unload(save_to_dataset)

    # BUG FIX: gr.Timer has no `.start()` method. Create the Timer component
    # inside the Blocks context and wire its `tick` event so the stats panel
    # actually refreshes every 60 seconds (the original no-op lambda did nothing).
    ui_timer = gr.Timer(60)
    ui_timer.tick(refresh_stats, outputs=[conversation_count, next_save_time])

if __name__ == "__main__":
    demo.launch()