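"""Gradio chat demo for the WillHeld/soft-raccoon checkpoint.

Streams model responses token by token and periodically saves conversations
to a Hugging Face dataset.
"""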
import spaces  # must be imported before any CUDA-touching library on ZeroGPU

import atexit
import os
import uuid
from datetime import datetime, timedelta
from threading import Thread

import gradio as gr
from datasets import Dataset
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

checkpoint = "WillHeld/soft-raccoon"
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
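# On ZeroGPU Spaces, importing `spaces` first lets the model be moved to
# "cuda" at startup; the GPU itself is attached only while functions
# decorated with @spaces.GPU are running.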

# Dataset configuration
DATASET_NAME = "WillHeld/soft-raccoon-conversations"  # Change to <your-username>/<dataset-name>
PUSH_TO_HUB = True  # Set to False if you just want to save locally first

# Time-based storage settings
SAVE_INTERVAL_MINUTES = 5  # Save every 5 minutes
last_save_time = datetime.now()

# Initialize storage for conversations
conversations = []

# Log in to the Hugging Face Hub (requires the HF_TOKEN environment variable)
login(token=os.environ.get("HF_TOKEN"))

def save_to_dataset():
    """Save the current conversations to a HuggingFace dataset"""
    if not conversations:
        return None
    
    # Convert conversations to the columnar layout Dataset.from_dict expects
    # (one list per column, aligned by row)
    dataset_dict = {
        "conversation_id": [],
        "timestamp": [],
        "messages": [],
        "metadata": []
    }
    
    for conv in conversations:
        dataset_dict["conversation_id"].append(conv["conversation_id"])
        dataset_dict["timestamp"].append(conv["timestamp"])
        dataset_dict["messages"].append(conv["messages"])
        dataset_dict["metadata"].append(conv["metadata"])
    
    # Create dataset
    dataset = Dataset.from_dict(dataset_dict)
    
    if PUSH_TO_HUB:
        try:
            # Push to hub - will create the dataset if it doesn't exist
            dataset.push_to_hub(DATASET_NAME)
            print(f"Successfully pushed {len(conversations)} conversations to {DATASET_NAME}")
        except Exception as e:
            print(f"Error pushing to hub: {e}")
            # Save locally as fallback
            dataset.save_to_disk("local_dataset")
    else:
        # Save locally
        dataset.save_to_disk("local_dataset")
        print(f"Saved {len(conversations)} conversations locally to 'local_dataset'")
    
    return dataset
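
# To reload saved conversations later (a sketch; assumes the dataset exists):
#   from datasets import load_dataset
#   ds = load_dataset(DATASET_NAME, split="train")
#   print(ds[0]["messages"])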

@spaces.GPU(duration=120)
def predict(message, history, temperature, top_p, conversation_id=None):
    # Create or retrieve conversation ID for tracking
    if conversation_id is None:
        conversation_id = str(uuid.uuid4())
    
    # Update history with user message
    history.append({"role": "user", "content": message})
    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
    
    # Create a streamer; skip_prompt drops the echoed input from the output
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    
    # Set up generation parameters
    generation_kwargs = {
        "input_ids": inputs,
        "max_new_tokens": 1024,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "do_sample": True,
        "streamer": streamer,
    }
    
    # Run generation in a separate thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    
    # Yield from the streamer as tokens are generated
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    
    # After generation completes, update history with assistant response
    history.append({"role": "assistant", "content": partial_text})
    
    # Store conversation data
    # Check if we already have this conversation
    existing_conv = next((c for c in conversations if c["conversation_id"] == conversation_id), None)
    
    if existing_conv:
        # Update existing conversation
        existing_conv["messages"] = history
        existing_conv["metadata"]["last_updated"] = datetime.now().isoformat()
    else:
        # Create new conversation record
        conversations.append({
            "conversation_id": conversation_id,
            "timestamp": datetime.now().isoformat(),
            "messages": history,
            "metadata": {
                "model": checkpoint,
                "temperature": temperature,
                "top_p": top_p,
                "last_updated": datetime.now().isoformat()
            }
        })
    
    # Check if it's time to save based on elapsed time
    global last_save_time
    current_time = datetime.now()
    if current_time - last_save_time > timedelta(minutes=SAVE_INTERVAL_MINUTES):
        save_to_dataset()
        last_save_time = current_time
    
    return partial_text

def save_dataset_button():
    """Manually save the current dataset"""
    dataset = save_to_dataset()
    if dataset:
        return f"Saved {len(conversations)} conversations to dataset."
    return "No conversations to save."

with gr.Blocks() as demo:
    conversation_id = gr.State(None)
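    # Note: gr.ChatInterface treats additional_inputs as read-only, so this
    # State is never written back; each request arrives with None and predict()
    # generates a fresh UUID per message. Newer Gradio releases expose an
    # additional_outputs parameter on ChatInterface if round-tripping the ID
    # is needed.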
    
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.ChatInterface(
                predict,
                additional_inputs=[
                    gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P"),
                    conversation_id
                ],
                type="messages"
            )
        
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Dataset Controls")
                save_button = gr.Button("Save conversations to dataset")
                save_output = gr.Textbox(label="Save Status")
                
                # Display current conversation count
                conversation_count = gr.Number(value=lambda: len(conversations), 
                                              label="Total Conversations", 
                                              interactive=False)
                
                # Display time until next auto-save
                next_save_time = gr.Textbox(
                    label="Next Auto-Save",
                    value=lambda: f"In {max(SAVE_INTERVAL_MINUTES - int((datetime.now() - last_save_time).total_seconds() // 60), 0)} minutes",
                )
                refresh_button = gr.Button("Refresh Stats")
    
    # Set up event handlers
    save_button.click(save_dataset_button, outputs=save_output)
    
    def refresh_stats():
        elapsed_minutes = int((datetime.now() - last_save_time).total_seconds() // 60)
        mins_until_save = max(SAVE_INTERVAL_MINUTES - elapsed_minutes, 0)
        return len(conversations), f"In {mins_until_save} minutes"
    
    refresh_button.click(refresh_stats, outputs=[conversation_count, next_save_time])
    
    # Save on process shutdown (gr.Blocks has no on_close hook; atexit covers
    # normal interpreter exit)
    atexit.register(save_to_dataset)
    
    # Periodically refresh the stats display (every 60 seconds)
    stats_timer = gr.Timer(60)
    stats_timer.tick(refresh_stats, outputs=[conversation_count, next_save_time])

if __name__ == "__main__":
    demo.launch()
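
# If pushing to the Hub fails, conversations fall back to ./local_dataset and
# can be reloaded with datasets.Dataset.load_from_disk("local_dataset") (sketch).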