import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread
from datetime import datetime, timedelta
from datasets import Dataset
from huggingface_hub import HfApi, login
import uuid
import os
import atexit
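# Rough dependency sketch (package list inferred from the imports above;
# the version pins are assumptions, not specified by this file):
#     pip install "gradio>=4.44" transformers datasets huggingface_hub spaces torch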
checkpoint = "WillHeld/soft-raccoon"
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
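# If GPU memory is tight, a half-precision load is a common alternative
# (a sketch, assuming a bfloat16-capable GPU and an added `import torch`):
#
#     model = AutoModelForCausalLM.from_pretrained(
#         checkpoint, torch_dtype=torch.bfloat16
#     ).to(device)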
# Dataset configuration
DATASET_NAME = "WillHeld/soft-raccoon-conversations" # Change to your HF username
PUSH_TO_HUB = True # Set to False if you just want to save locally first
# Time-based storage settings
SAVE_INTERVAL_MINUTES = 5 # Save every 5 minutes
last_save_time = datetime.now()
# Initialize storage for conversations
conversations = []
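# Each record appended to `conversations` (see predict() below) has this shape:
#     {
#         "conversation_id": "<uuid4 string>",
#         "timestamp": "<ISO-8601 string>",
#         "messages": [{"role": "user" | "assistant", "content": "..."}, ...],
#         "metadata": {"model": ..., "temperature": ..., "top_p": ..., "last_updated": ...},
#     }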
# Login to the Hugging Face Hub (expects the HF_TOKEN environment variable to be set)
login(token=os.environ.get("HF_TOKEN"))
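# The token can come from the Space's secrets or the shell; push_to_hub needs
# a write-scoped token. For example (the value shown is a placeholder):
#     export HF_TOKEN=hf_...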
def save_to_dataset():
    """Save the current conversations to a HuggingFace dataset."""
    if not conversations:
        return None

    # Convert conversations to columnar dataset format
    dataset_dict = {
        "conversation_id": [],
        "timestamp": [],
        "messages": [],
        "metadata": [],
    }
    for conv in conversations:
        dataset_dict["conversation_id"].append(conv["conversation_id"])
        dataset_dict["timestamp"].append(conv["timestamp"])
        dataset_dict["messages"].append(conv["messages"])
        dataset_dict["metadata"].append(conv["metadata"])

    # Create the dataset
    dataset = Dataset.from_dict(dataset_dict)

    if PUSH_TO_HUB:
        try:
            # Push to the Hub; creates the dataset repo if it doesn't exist
            dataset.push_to_hub(DATASET_NAME)
            print(f"Successfully pushed {len(conversations)} conversations to {DATASET_NAME}")
        except Exception as e:
            print(f"Error pushing to hub: {e}")
            # Save locally as a fallback
            dataset.save_to_disk("local_dataset")
    else:
        # Save locally only
        dataset.save_to_disk("local_dataset")
        print(f"Saved {len(conversations)} conversations locally to 'local_dataset'")

    return dataset
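# A minimal sketch of reading the pushed dataset back, assuming push_to_hub
# has run at least once and HF_TOKEN grants read access to the repo:
#
#     from datasets import load_dataset
#     ds = load_dataset(DATASET_NAME, split="train")
#     print(ds[0]["messages"])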
@spaces.GPU(duration=120)
def predict(message, history, temperature, top_p, conversation_id=None):
    # Create or retrieve a conversation ID for tracking
    if conversation_id is None:
        conversation_id = str(uuid.uuid4())

    # Update history with the user message and build the prompt
    history.append({"role": "user", "content": message})
    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Create a streamer so tokens can be yielded as they are generated
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Set up generation parameters
    generation_kwargs = {
        "input_ids": inputs,
        "max_new_tokens": 1024,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "do_sample": True,
        "streamer": streamer,
    }

    # Run generation in a separate thread so this one can consume the streamer
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield the accumulated text as tokens arrive
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    # After generation completes, record the assistant response
    history.append({"role": "assistant", "content": partial_text})

    # Store conversation data, updating the existing record if this ID was seen before
    existing_conv = next((c for c in conversations if c["conversation_id"] == conversation_id), None)
    if existing_conv:
        existing_conv["messages"] = history
        existing_conv["metadata"]["last_updated"] = datetime.now().isoformat()
    else:
        conversations.append({
            "conversation_id": conversation_id,
            "timestamp": datetime.now().isoformat(),
            "messages": history,
            "metadata": {
                "model": checkpoint,
                "temperature": temperature,
                "top_p": top_p,
                "last_updated": datetime.now().isoformat(),
            },
        })

    # Auto-save if enough time has elapsed since the last save
    global last_save_time
    current_time = datetime.now()
    if current_time - last_save_time > timedelta(minutes=SAVE_INTERVAL_MINUTES):
        save_to_dataset()
        last_save_time = current_time
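# predict() is a generator, so a direct (non-Gradio) call streams the
# accumulated response. A hypothetical smoke test:
#
#     for chunk in predict("Hello!", [], temperature=0.7, top_p=0.9):
#         pass
#     print(chunk)  # full response text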
def save_dataset_button():
    """Manually save the current dataset."""
    dataset = save_to_dataset()
    if dataset:
        return f"Saved {len(conversations)} conversations to dataset."
    return "No conversations to save."
with gr.Blocks() as demo:
    conversation_id = gr.State(None)

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.ChatInterface(
                predict,
                additional_inputs=[
                    gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P"),
                    conversation_id,
                ],
                type="messages",
            )

        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Dataset Controls")
                save_button = gr.Button("Save conversations to dataset")
                save_output = gr.Textbox(label="Save Status")

                # Display the current conversation count
                conversation_count = gr.Number(
                    value=lambda: len(conversations),
                    label="Total Conversations",
                    interactive=False,
                )

                # Display time until the next auto-save
                next_save_time = gr.Textbox(
                    label="Next Auto-Save",
                    value=lambda: f"In {SAVE_INTERVAL_MINUTES - (datetime.now() - last_save_time).seconds // 60} minutes",
                )
                refresh_button = gr.Button("Refresh Stats")

    # Set up event handlers
    save_button.click(save_dataset_button, outputs=save_output)

    def refresh_stats():
        mins_until_save = SAVE_INTERVAL_MINUTES - (datetime.now() - last_save_time).seconds // 60
        return len(conversations), f"In {mins_until_save} minutes"

    refresh_button.click(refresh_stats, outputs=[conversation_count, next_save_time])

    # Periodic UI refresh: the timer fires a tick event every 60 seconds
    timer = gr.Timer(60)
    timer.tick(refresh_stats, outputs=[conversation_count, next_save_time])

# Save any unsaved conversations when the process shuts down
atexit.register(save_to_dataset)
if __name__ == "__main__":
demo.launch()