AbstractPhil committed (verified)
Commit 6bede26 · Parent(s): fae170d

Update app.py

Files changed (1):
  1. app.py +147 -34
app.py CHANGED
@@ -6,8 +6,22 @@ from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file as load_safetensors
 
 # ----------------------------
-# 🔧 Load Model and Tokenizer
+# 🔧 Model versions configuration
 # ----------------------------
+MODEL_VERSIONS = {
+    "Beeper v1 (Original)": {
+        "repo_id": "AbstractPhil/beeper-rose-tinystories-6l-512d-ctx512",
+        "model_file": "beeper_rose_final.safetensors",
+        "description": "Original Beeper trained on TinyStories"
+    },
+    "Beeper v2 (Extended)": {
+        "repo_id": "AbstractPhil/beeper-rose-v2",
+        "model_file": "beeper_final.safetensors",
+        "description": "Beeper v2 with extended training (~15 epochs)"
+    }
+}
+
+# Base configuration
 config = {
     "context": 512,
     "vocab_size": 8192,
@@ -21,7 +35,7 @@ config = {
     "repetition_penalty": 1.1,
     "presence_penalty": 0.6,
     "frequency_penalty": 0.0,
-    "resid_dropout": 0.1, # Add these for model init
+    "resid_dropout": 0.1,
     "dropout": 0.0,
     "grad_checkpoint": False,
     "tokenizer_path": "beeper.tokenizer.json"
@@ -29,26 +43,68 @@
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# Load weights from Hugging Face repo
-repo_id = "AbstractPhil/beeper-rose-tinystories-6l-512d-ctx512"
-model_file = hf_hub_download(repo_id=repo_id, filename="beeper_rose_final.safetensors")
-tokenizer_file = hf_hub_download(repo_id=repo_id, filename="tokenizer.json")
-
-# Initialize model
-infer = BeeperRoseGPT(config).to(device)
+# Global model and tokenizer variables
+infer = None
+tok = None
+current_version = None
 
-# Load safetensors properly
-state_dict = load_safetensors(model_file, device=str(device))
-infer.load_state_dict(state_dict)
-infer.eval()
+def load_model_version(version_name):
+    """Load the selected model version"""
+    global infer, tok, current_version
+
+    if current_version == version_name and infer is not None:
+        return f"Already loaded: {version_name}"
+
+    version_info = MODEL_VERSIONS[version_name]
+
+    try:
+        # Download model and tokenizer files
+        model_file = hf_hub_download(
+            repo_id=version_info["repo_id"],
+            filename=version_info["model_file"]
+        )
+        tokenizer_file = hf_hub_download(
+            repo_id=version_info["repo_id"],
+            filename="tokenizer.json"
+        )
+
+        # Initialize model
+        infer = BeeperRoseGPT(config).to(device)
+
+        # Load safetensors
+        state_dict = load_safetensors(model_file, device=str(device))
+        infer.load_state_dict(state_dict)
+        infer.eval()
+
+        # Load tokenizer
+        tok = Tokenizer.from_file(tokenizer_file)
+
+        current_version = version_name
+        return f"Successfully loaded: {version_name}"
+
+    except Exception as e:
+        return f"Error loading {version_name}: {str(e)}"
 
-# Load tokenizer
-tok = Tokenizer.from_file(tokenizer_file)
+# Load default model on startup
+load_status = load_model_version("Beeper v1 (Original)")
+print(load_status)
 
 # ----------------------------
 # 💬 Gradio Chat Wrapper
 # ----------------------------
-def beeper_reply(message, history, temperature=None, top_k=None, top_p=None):
+def beeper_reply(message, history, model_version, temperature=None, top_k=None, top_p=None):
+    global infer, tok, current_version
+
+    # Load model if version changed
+    if model_version != current_version:
+        status = load_model_version(model_version)
+        if "Error" in status:
+            return f"⚠️ {status}"
+
+    # Check if model is loaded
+    if infer is None or tok is None:
+        return "⚠️ Model not loaded. Please select a version and try again."
+
     # Use defaults if not provided (for examples caching)
     if temperature is None:
         temperature = 0.9
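load_model_version keeps exactly one model in memory: switching versions re-runs the download (served from the local cache) and rebuilds the model, so flipping back and forth pays the load cost each time. A sketch of a per-version cache as an alternative, using the commit's own names (_model_cache and get_version are hypothetical helpers, not in the commit), trading memory for switch latency:

```python
_model_cache = {}  # version_name -> (model, tokenizer)

def get_version(version_name):
    if version_name not in _model_cache:
        info = MODEL_VERSIONS[version_name]
        weights = hf_hub_download(repo_id=info["repo_id"], filename=info["model_file"])
        tok_file = hf_hub_download(repo_id=info["repo_id"], filename="tokenizer.json")
        model = BeeperRoseGPT(config).to(device)
        model.load_state_dict(load_safetensors(weights, device=str(device)))
        model.eval()
        _model_cache[version_name] = (model, Tokenizer.from_file(tok_file))
    return _model_cache[version_name]
```

Either way, every version is instantiated from the same config, so the commit assumes v1 and v2 share one architecture; a checkpoint with mismatched shapes would fail in load_state_dict and surface through the except branch.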
@@ -98,24 +154,81 @@ def beeper_reply(message, history, temperature=None, top_k=None, top_p=None):
 # ----------------------------
 # 🖼️ Interface
 # ----------------------------
-demo = gr.ChatInterface(
-    beeper_reply,
-    additional_inputs=[
-        gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature"),
-        gr.Slider(1, 100, value=40, step=1, label="Top-k"),
-        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
-    ],
-    chatbot=gr.Chatbot(label="Chat with Beeper 🤖", type="messages"),
-    title="Beeper - A Rose-based Tiny Language Model",
-    description="Hello! I'm Beeper, a small language model trained with love and care. Please be patient with me - I'm still learning! 💕",
-    examples=[
-        ["Hello Beeper! How are you today?"],
-        ["Can you tell me a story about a robot?"],
-        ["What do you like to do for fun?"],
-    ],
-    theme=gr.themes.Soft(),
-    cache_examples=False, # Disable caching to avoid the startup issue
-)
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown(
+        """
+        # 🤖 Beeper - A Rose-based Tiny Language Model
+        Hello! I'm Beeper, a small language model trained with love and care. Please be patient with me - I'm still learning! 💕
+        """
+    )
+
+    with gr.Row():
+        with gr.Column(scale=3):
+            model_dropdown = gr.Dropdown(
+                choices=list(MODEL_VERSIONS.keys()),
+                value="Beeper v1 (Original)",
+                label="Select Beeper Version",
+                info="Choose which version of Beeper to chat with"
+            )
+        with gr.Column(scale=7):
+            version_info = gr.Markdown("**Current:** Beeper v1 - Original training on TinyStories")
+
+    # Update version info when dropdown changes
+    def update_version_info(version_name):
+        info = MODEL_VERSIONS[version_name]["description"]
+        return f"**Current:** {info}"
+
+    model_dropdown.change(
+        fn=update_version_info,
+        inputs=[model_dropdown],
+        outputs=[version_info]
+    )
+
+    # Chat interface
+    chatbot = gr.Chatbot(label="Chat with Beeper", type="messages", height=400)
+    msg = gr.Textbox(label="Message", placeholder="Type your message here...")
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            temperature_slider = gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature")
+        with gr.Column(scale=2):
+            top_k_slider = gr.Slider(1, 100, value=40, step=1, label="Top-k")
+        with gr.Column(scale=2):
+            top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
+
+    with gr.Row():
+        submit = gr.Button("Send", variant="primary")
+        clear = gr.Button("Clear")
+
+    # Examples
+    gr.Examples(
+        examples=[
+            ["Hello Beeper! How are you today?"],
+            ["Can you tell me a story about a robot?"],
+            ["What do you like to do for fun?"],
+            ["What makes you happy?"],
+            ["Tell me about your dreams"],
+        ],
+        inputs=msg
+    )
+
+    # Handle chat
+    def respond(message, chat_history, model_version, temperature, top_k, top_p):
+        response = beeper_reply(message, chat_history, model_version, temperature, top_k, top_p)
+        chat_history += [{"role": "user", "content": message}, {"role": "assistant", "content": response}]
+        return "", chat_history
+
+    msg.submit(
+        respond,
+        [msg, chatbot, model_dropdown, temperature_slider, top_k_slider, top_p_slider],
+        [msg, chatbot]
+    )
+    submit.click(
+        respond,
+        [msg, chatbot, model_dropdown, temperature_slider, top_k_slider, top_p_slider],
+        [msg, chatbot]
+    )
+    clear.click(lambda: None, None, chatbot, queue=False)
 
 if __name__ == "__main__":
     demo.launch()
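Because the chatbot is created with type="messages", Gradio expects the history to be a list of role/content dicts (the OpenAI-style messages format) rather than [user, bot] pairs, which is why respond appends two dicts per turn. The shape a two-turn history takes (the strings here are illustrative):

```python
# Gradio "messages" format: one dict per turn, roles alternate
history = [
    {"role": "user", "content": "Hello Beeper! How are you today?"},
    {"role": "assistant", "content": "Beep! I am happy today."},
]
```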
 