# app.py — WillHeld/soft-raccoon: Gradio chat demo (Hugging Face Space)
# (header reconstructed from scrape residue: Space status banner, file size,
# commit hashes, and a line-number gutter were captured here by mistake)
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread
from datetime import datetime, timedelta
from datasets import Dataset
from huggingface_hub import HfApi, login
import uuid
import os
import time
checkpoint = "WillHeld/soft-raccoon"
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

# Dataset configuration
DATASET_NAME = "WillHeld/soft-raccoon-conversations"  # Change to your HF username
PUSH_TO_HUB = True  # Set to False if you just want to save locally first

# Time-based storage settings
SAVE_INTERVAL_MINUTES = 5  # Save every 5 minutes
last_save_time = datetime.now()

# In-memory store of conversation records; flushed by save_to_dataset().
conversations = []

# Login to the Hugging Face Hub only when a token is actually configured.
# Calling login(token=None) falls back to an interactive prompt, which
# fails in a headless Space runtime.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
def save_to_dataset():
    """Snapshot the in-memory ``conversations`` list to a HuggingFace dataset.

    Returns:
        The ``datasets.Dataset`` that was saved, or ``None`` when there are
        no conversations to save.

    When ``PUSH_TO_HUB`` is True the dataset is pushed to ``DATASET_NAME``
    (created on first push), falling back to a local ``save_to_disk`` on
    failure; otherwise it is only saved locally.
    """
    if not conversations:
        return None

    # Column-orient the records: one list per dataset column.
    columns = ("conversation_id", "timestamp", "messages", "metadata")
    dataset_dict = {col: [conv[col] for conv in conversations] for col in columns}
    dataset = Dataset.from_dict(dataset_dict)

    if PUSH_TO_HUB:
        try:
            # Push to hub - will create the dataset if it doesn't exist
            dataset.push_to_hub(DATASET_NAME)
            print(f"Successfully pushed {len(conversations)} conversations to {DATASET_NAME}")
        except Exception as e:
            print(f"Error pushing to hub: {e}")
            # Save locally as fallback so nothing is lost.
            dataset.save_to_disk("local_dataset")
    else:
        # Save locally
        dataset.save_to_disk("local_dataset")
        print(f"Saved {len(conversations)} conversations locally to 'local_dataset'")
    return dataset
@spaces.GPU(duration=120)
def predict(message, history, temperature, top_p, conversation_id=None):
    """Stream a chat completion and log the exchange in ``conversations``.

    Generator consumed by ``gr.ChatInterface``: yields the growing partial
    assistant response as tokens arrive, then records the full turn in the
    module-level ``conversations`` list and triggers a periodic auto-save.

    Args:
        message: Latest user message.
        history: Chat history as a list of ``{"role", "content"}`` dicts;
            mutated in place with the new user and assistant turns.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling mass passed to ``model.generate``.
        conversation_id: Identifier used to group turns into one record.
            NOTE(review): the id generated here is never written back to the
            ``gr.State`` that supplies this argument, so it presumably
            arrives as ``None`` on every call and each turn creates a fresh
            record — confirm against the UI wiring.

    Yields:
        The accumulated assistant response text after each new token chunk.
    """
    # Create or retrieve conversation ID for tracking
    if conversation_id is None:
        conversation_id = str(uuid.uuid4())
    # Update history with user message
    history.append({"role": "user", "content": message})
    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
    # Create a streamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Set up generation parameters
    generation_kwargs = {
        "input_ids": inputs,
        "max_new_tokens": 1024,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "do_sample": True,
        "streamer": streamer,
    }
    # Run generation in a separate thread; the streamer feeds this generator
    # as tokens are produced on the worker thread.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # Yield from the streamer as tokens are generated
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    # After generation completes, update history with assistant response
    history.append({"role": "assistant", "content": partial_text})
    # Store conversation data
    # Check if we already have this conversation
    existing_conv = next((c for c in conversations if c["conversation_id"] == conversation_id), None)
    if existing_conv:
        # Update existing conversation
        existing_conv["messages"] = history
        existing_conv["metadata"]["last_updated"] = datetime.now().isoformat()
    else:
        # Create new conversation record
        conversations.append({
            "conversation_id": conversation_id,
            "timestamp": datetime.now().isoformat(),
            "messages": history,
            "metadata": {
                "model": checkpoint,
                "temperature": temperature,
                "top_p": top_p,
                "last_updated": datetime.now().isoformat()
            }
        })
    # Check if it's time to save based on elapsed time
    global last_save_time
    current_time = datetime.now()
    if current_time - last_save_time > timedelta(minutes=SAVE_INTERVAL_MINUTES):
        save_to_dataset()
        last_save_time = current_time
    return partial_text
def save_dataset_button():
    """Persist the collected conversations on demand and report the outcome."""
    result = save_to_dataset()
    return (
        f"Saved {len(conversations)} conversations to dataset."
        if result
        else "No conversations to save."
    )
# Local import keeps this edit self-contained; atexit is stdlib.
import atexit

with gr.Blocks() as demo:
    # Per-session conversation id, threaded through predict() as an
    # additional input so successive turns can share one record.
    conversation_id = gr.State(None)

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.ChatInterface(
                predict,
                additional_inputs=[
                    gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P"),
                    conversation_id
                ],
                type="messages"
            )

        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Dataset Controls")
                save_button = gr.Button("Save conversations to dataset")
                save_output = gr.Textbox(label="Save Status")

                # Display current conversation count
                conversation_count = gr.Number(value=lambda: len(conversations),
                                               label="Total Conversations",
                                               interactive=False)

                # Display time until next auto-save
                next_save_time = gr.Textbox(label="Next Auto-Save",
                                            value=lambda: f"In {SAVE_INTERVAL_MINUTES - (datetime.now() - last_save_time).seconds // 60} minutes")

                refresh_button = gr.Button("Refresh Stats")

    def refresh_stats():
        """Recompute the stats panel: total count and minutes until auto-save."""
        mins_until_save = SAVE_INTERVAL_MINUTES - (datetime.now() - last_save_time).seconds // 60
        return len(conversations), f"In {mins_until_save} minutes"

    # Set up event handlers
    save_button.click(save_dataset_button, outputs=save_output)
    refresh_button.click(refresh_stats, outputs=[conversation_count, next_save_time])

    # BUG FIX: gr.Blocks has no on_close() method — the original
    # `demo.on_close(save_to_dataset)` raised AttributeError at startup.
    # Register the shutdown save with atexit so conversations are flushed
    # when the process exits.
    atexit.register(save_to_dataset)

    # BUG FIX: gr.Timer takes an interval in seconds and fires `.tick`
    # events; it has neither a callback positional argument nor a
    # `.start()` method. Use the tick event to auto-refresh the stats
    # panel every 60 seconds.
    ui_timer = gr.Timer(60)
    ui_timer.tick(refresh_stats, outputs=[conversation_count, next_save_time])

if __name__ == "__main__":
    demo.launch()