WillHeld's picture
Update app.py
752950e verified
raw
history blame
10.4 kB
import atexit
import json
import os
import time
import uuid
from datetime import datetime, timedelta
from threading import Thread

# Third-party imports (PyTorch, Gradio and Hugging Face)
import gradio as gr
import torch
from datasets import Dataset
from gradio.themes import Base
from huggingface_hub import HfApi, login
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Model configuration
checkpoint = "WillHeld/soft-raccoon"
# Prefer the GPU, but fall back to CPU so the app can still start on a host
# without CUDA instead of failing in `.to(device)` below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Dataset configuration
DATASET_NAME = "your-username/soft-raccoon-conversations"  # Change to your username
SAVE_INTERVAL_MINUTES = 5  # Save data every 5 minutes
# Time of the most recent auto-save; starts at import time so the first
# auto-save happens one full interval after startup.
last_save_time = datetime.now()

# Initialize model and tokenizer
print(f"Loading model from {checkpoint}...")
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

# Data storage: one dict per conversation, kept in memory until saved
conversations = []

# Hugging Face authentication
# Uncomment this line to login with your token
# login(token=os.environ.get("HF_TOKEN"))
def save_to_dataset():
    """Push all recorded conversations to the Hub, falling back to local disk.

    Returns:
        A ``(dataset, status_message)`` tuple. ``dataset`` is ``None`` when
        there is nothing to save; otherwise it is the ``Dataset`` that was
        built, whether the push succeeded or the local fallback was used.
    """
    if len(conversations) == 0:
        attempt_time = datetime.now().strftime('%H:%M:%S')
        return None, f"No conversations to save. Last attempt: {attempt_time}"

    # Column-oriented layout; nested structures are stored as JSON strings.
    columns = {
        "conversation_id": [record["conversation_id"] for record in conversations],
        "timestamp": [record["timestamp"] for record in conversations],
        "messages": [json.dumps(record["messages"]) for record in conversations],
        "metadata": [json.dumps(record["metadata"]) for record in conversations],
    }
    dataset = Dataset.from_dict(columns)

    try:
        dataset.push_to_hub(DATASET_NAME)
    except Exception as exc:
        # Save locally as a fallback so the data is not lost.
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        local_path = f"local_dataset_{stamp}"
        dataset.save_to_disk(local_path)
        outcome = f"Error pushing to hub: {exc}. Saved locally to '{local_path}'"
    else:
        outcome = f"Successfully saved {len(conversations)} conversations to {DATASET_NAME}"

    print(outcome)
    return dataset, outcome
def predict(message, chat_history, temperature, top_p, conversation_id=None):
    """Stream a model reply to ``message`` and record the conversation.

    This is a generator. Each yielded item is a
    ``(updated_chat_history, conversation_id)`` pair in which the last
    history entry is ``[message, reply_so_far]``.

    Args:
        message: The new user message.
        chat_history: Earlier turns as ``[user_text, assistant_text]`` pairs;
            ``None`` is treated as an empty history.
        temperature: Sampling temperature (passed to the model as ``float``).
        top_p: Nucleus-sampling threshold (passed to the model as ``float``).
        conversation_id: Tracking id for the conversation. A new UUID4 string
            is generated when it is ``None`` or empty.

    Yields:
        The growing chat history together with the conversation id, once per
        streamed chunk and once more with the complete reply.
    """
    global last_save_time

    # Create/retrieve conversation ID for tracking
    if not conversation_id:
        conversation_id = str(uuid.uuid4())

    chat_history = chat_history or []

    # Format chat history for the model
    formatted_history = []
    for human_msg, ai_msg in chat_history:
        formatted_history.append({"role": "user", "content": human_msg})
        if ai_msg:  # Skip None values that might occur during streaming
            formatted_history.append({"role": "assistant", "content": ai_msg})
    # Add the current message
    formatted_history.append({"role": "user", "content": message})

    # Prepare input for the model
    input_text = tokenizer.apply_chat_template(
        formatted_history,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered template already contains the special tokens, so they must
    # not be added a second time when the text is tokenized.
    inputs = tokenizer.encode(
        input_text,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(device)

    # Set up streaming
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "input_ids": inputs,
        "max_new_tokens": 1024,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "do_sample": True,
        "streamer": streamer,
    }

    # Generate in a separate thread so the streamer can be consumed here
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield partial text as it's generated
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield chat_history + [[message, partial_text]], conversation_id
    # The streamer is exhausted; wait for the worker so it is not left dangling.
    thread.join()

    # Update history with final response
    formatted_history.append({"role": "assistant", "content": partial_text})

    # Update or create conversation record
    current_time = datetime.now().isoformat()
    existing_conv = next(
        (c for c in conversations if c["conversation_id"] == conversation_id),
        None,
    )
    if existing_conv:
        existing_conv["messages"] = formatted_history
        existing_conv["metadata"]["last_updated"] = current_time
        existing_conv["metadata"]["temperature"] = temperature
        existing_conv["metadata"]["top_p"] = top_p
    else:
        conversations.append({
            "conversation_id": conversation_id,
            "timestamp": current_time,
            "messages": formatted_history,
            "metadata": {
                "model": checkpoint,
                "temperature": temperature,
                "top_p": top_p,
                "last_updated": current_time,
            },
        })

    # Check if it's time to save based on elapsed time
    current_time_dt = datetime.now()
    if current_time_dt - last_save_time > timedelta(minutes=SAVE_INTERVAL_MINUTES):
        save_to_dataset()
        last_save_time = current_time_dt

    # A generator's `return` value is not delivered to whoever iterates it, so
    # the final state is yielded; this also covers a reply with no streamed text.
    yield chat_history + [[message, partial_text]], conversation_id
def save_dataset_manually():
    """Run a save right now and return only its status message."""
    return save_to_dataset()[1]
def get_stats():
    """Return a snapshot of conversation and auto-save statistics.

    Returns:
        A dict with the keys ``conversation_count`` (number of recorded
        conversations), ``next_save`` (text giving whole minutes until the
        next auto-save, never below zero), ``last_save`` (``HH:MM:SS`` of the
        last save time) and ``dataset_name``.
    """
    # Use total_seconds(): timedelta.seconds is only the seconds component and
    # wraps around every 24 hours, which would miscount long gaps.
    elapsed = datetime.now() - last_save_time
    elapsed_minutes = int(elapsed.total_seconds() // 60)
    mins_until_save = max(0, SAVE_INTERVAL_MINUTES - elapsed_minutes)
    return {
        "conversation_count": len(conversations),
        "next_save": f"In {mins_until_save} minutes",
        "last_save": last_save_time.strftime('%H:%M:%S'),
        "dataset_name": DATASET_NAME,
    }
# Create a custom Stanford theme
class StanfordTheme(gr.Theme):
    """Gradio theme built on a cardinal-red primary and cool-gray secondary palette."""

    def __init__(self):
        # Custom palettes are passed as Color objects: the theme constructor
        # reads each shade as an attribute (c50 ... c950), which a plain dict
        # does not provide.
        cardinal = gr.themes.colors.Color(
            name="cardinal",
            c50="#F9E8E8",
            c100="#F0C9C9",
            c200="#E39B9B",
            c300="#D66E6E",
            c400="#C94A4A",
            c500="#B82C2C",
            c600="#8C1515",
            c700="#771212",
            c800="#620E0E",
            c900="#4D0A0A",
            c950="#380707",
        )
        cool_gray = gr.themes.colors.Color(
            name="cool_gray",
            c50="#F5F5F6",
            c100="#E6E7E8",
            c200="#CDCED0",
            c300="#B3B5B8",
            c400="#9A9CA0",
            c500="#818388",
            c600="#4D4F53",
            c700="#424448",
            c800="#36383A",
            c900="#2E2D29",
            c950="#1D1D1B",
        )
        super().__init__(
            primary_hue=cardinal,
            secondary_hue=cool_gray,
            neutral_hue="gray",
            radius_size=gr.themes.sizes.radius_sm,
            font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui"],
        )
# Use the Stanford theme
theme = StanfordTheme()

# Set up the Gradio app
with gr.Blocks(theme=theme, title="Stanford Soft Raccoon Chat with Dataset Collection") as demo:
    # Per-session conversation id; empty until the first message is sent.
    conversation_id = gr.State("")

    with gr.Row():
        # Left column: chat window, input box and sampling controls.
        with gr.Column(scale=3):
            # NOTE(review): avatar_images is documented as image paths/URLs;
            # confirm that an emoji string is accepted by the installed Gradio.
            chatbot = gr.Chatbot(
                label="Soft Raccoon Chat",
                avatar_images=(None, "🦝"),
                height=600
            )
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Send a message...",
                    show_label=False,
                    container=False
                )
                submit_btn = gr.Button("Send", variant="primary")
            with gr.Accordion("Generation Parameters", open=False):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top-P"
                )

        # Right column: dataset controls and statistics.
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Dataset Controls")
                save_button = gr.Button("Save conversations now", variant="secondary")
                status_output = gr.Textbox(label="Save Status", interactive=False)
                with gr.Row():
                    convo_count = gr.Number(label="Total Conversations", interactive=False)
                    next_save = gr.Textbox(label="Next Auto-Save", interactive=False)
                last_save_time_display = gr.Textbox(label="Last Save Time", interactive=False)
                dataset_name_display = gr.Textbox(label="Dataset Name", interactive=False)
                refresh_btn = gr.Button("Refresh Stats")

    # Set up event handlers
    chat_inputs = [msg, chatbot, temperature, top_p, conversation_id]
    chat_outputs = [chatbot, conversation_id]
    stats_outputs = [convo_count, next_save, last_save_time_display, dataset_name_display]

    submit_btn.click(
        predict,
        chat_inputs,
        chat_outputs,
        api_name="chat"
    )
    msg.submit(
        predict,
        chat_inputs,
        chat_outputs,
        api_name=False
    )
    save_button.click(
        save_dataset_manually,
        [],
        [status_output]
    )

    def update_stats():
        """Return the current stats in the order expected by the stats widgets."""
        stats = get_stats()
        return [
            stats["conversation_count"],
            stats["next_save"],
            stats["last_save"],
            stats["dataset_name"]
        ]

    refresh_btn.click(
        update_stats,
        [],
        stats_outputs
    )

    # Auto-update stats every 30 seconds. A Timer component fires `tick`
    # periodically; Gradio has no `gr.Timeout`.
    stats_timer = gr.Timer(30)
    stats_timer.tick(
        update_stats,
        [],
        stats_outputs
    )

    # Set up a function that will be called when the demo loads
    def on_startup():
        """Populate the stats widgets when the page first loads."""
        return update_stats()

    # Registered inside the Blocks context, where event listeners are attached.
    demo.load(on_startup, [], stats_outputs)

# Ensure we save on shutdown using atexit
atexit.register(save_to_dataset)

# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)