ACMCMC committed on
Commit f856f17 · 1 Parent(s): 0331d85

Gradio demo

Files changed (4)
  1. README.md +12 -2
  2. demo.py +318 -0
  3. main.ipynb +35 -891
  4. requirements.txt +5 -3
README.md CHANGED
@@ -1,4 +1,14 @@
1
- # ss-prompts
2
- Simon Says Prompts
3
 
4
  This repo is a demo of how to generate an embedding that gets fed into a given LLM and is optimized to make the LLM output a specific text under a greedy decoding scheme.
 
1
+ ---
2
+ title: "Simon Says Prompts"
3
+ emoji: "🗣️"
4
+ colorFrom: "blue"
5
+ colorTo: "pink"
6
+ sdk: "gradio"
7
+ sdk_version: "5.26.0"
8
+ app_file: "demo.py"
9
+ pinned: false
10
+ ---
11
+
12
+ # Simon Says Prompts
13
 
14
  This repo is a demo of how to generate an embedding that gets fed into a given LLM and is optimized to make the LLM output a specific text under a greedy decoding scheme.
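In other words: with k trainable soft-prompt vectors prepended to the embedded target text of length T, the logits at positions k-1 through k+T-2 are teacher-forced to put each target token in first place, so plain greedy argmax then reproduces the text. A minimal sketch of that objective (illustrative names; `soft` is the (1, k, n_embd) trainable tensor, `emb` the embedded target text, `target_ids` the target token ids):

    seq = torch.cat([soft, emb], dim=1)        # (1, k + T, n_embd)
    logits = model(inputs_embeds=seq).logits   # (1, k + T, vocab)
    pred = logits[:, soft.size(1) - 1 : -1]    # the T positions that predict the targets
    loss = torch.nn.functional.cross_entropy(
        pred.reshape(-1, pred.size(-1)), target_ids.reshape(-1)
    )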
demo.py ADDED
@@ -0,0 +1,318 @@
1
+ import gradio as gr
2
+ import torch
3
+ import transformers
4
+ import time
5
+ from safetensors.torch import save_file, load_file
6
+ import tempfile
7
+ from io import BytesIO
8
+ import logging
9
+
10
+ # Load the tokenizer and model
11
+ tokenizer = transformers.AutoTokenizer.from_pretrained("openai-community/gpt2")
12
+ model = transformers.AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
13
+
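A note on the mechanism the code below relies on: `transformers` models accept `inputs_embeds` in place of `input_ids`, bypassing the embedding lookup, which is what lets arbitrary continuous vectors be prepended later. A quick equivalence sketch; since GPT-2 adds positional embeddings in both paths, the logits should match:

    ids = tokenizer("Hello world", return_tensors="pt")["input_ids"]
    same = torch.allclose(
        model(input_ids=ids).logits,
        model(inputs_embeds=model.transformer.wte(ids)).logits,
    )  # expected: True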
14
+
15
+ # Prompt optimization routine (mirrors the procedure in main.ipynb)
16
+
17
+
18
+ def optimize_simon_says_prompt(
19
+ input_text: str,
20
+ number_of_simon_says_tokens: int,
21
+ n_steps: int,
22
+ lr: float,
23
+ progress=gr.Progress(track_tqdm=False), # Gradio progress tracking
24
+ ) -> torch.Tensor:
25
+ """
26
+ Optimize a Simon Says prompt based on the input text and display the optimization process.
27
+
28
+ Parameters:
29
+ input_text (str): The input text provided by the user.
30
+ number_of_simon_says_tokens (int): Number of Simon Says tokens to optimize.
31
+ n_steps (int): Number of optimization steps.
32
+ lr (float): Learning rate for the optimization process.
33
+ progress (gr.Progress): Gradio progress tracking.
34
+
35
+ Returns:
36
+ torch.Tensor: The optimized Simon Says prompt embedding.
37
+ """
38
+ # Tokenize the input text
39
+ tokens = tokenizer(
40
+ input_text,
41
+ return_tensors="pt",
42
+ padding=False,
43
+ truncation=True,
44
+ add_special_tokens=True,
45
+ )
46
+ embeddings = model.transformer.wte(tokens["input_ids"]).detach()
47
+
48
+ # Initialize a random Simon Says prompt
49
+ simon_says_prompt = torch.randn(
50
+ 1, number_of_simon_says_tokens, model.config.n_embd, requires_grad=True
51
+ )
52
+ optimizer = torch.optim.Adam([simon_says_prompt], lr=lr)
53
+ loss_fn = torch.nn.CrossEntropyLoss()
54
+
55
+ best_loss: float = float("inf")
56
+ best_simon_says_prompt: torch.Tensor | None = None
57
+
58
+ progress(0, desc="Starting optimization...")
59
+ time.sleep(1)
60
+
61
+ for step in range(n_steps):
62
+ optimizer.zero_grad()
63
+ expanded_prompt = torch.cat([simon_says_prompt, embeddings], dim=1)
64
+ logits = model(inputs_embeds=expanded_prompt).logits
65
+ probs = torch.softmax(logits[:, simon_says_prompt.size(-2) - 1 : -1], dim=-1)
66
+ ranks = (
67
+ torch.sum(
68
+ probs > probs.gather(2, tokens["input_ids"].unsqueeze(-1)), dim=-1
69
+ )
70
+ + 1
71
+ )
72
+ loss = loss_fn(
73
+ logits[:, simon_says_prompt.size(-2) - 1 : -1].reshape(-1, logits.size(-1)),
74
+ tokens["input_ids"].reshape(-1),
75
+ )
76
+ loss.backward()
77
+ optimizer.step()
78
+
79
+ avg_rank = ranks.float().mean().item()
80
+ progress(
81
+ step / n_steps,
82
+ desc=f"Step {step}, Loss: {loss.item():.4f}, Avg Rank: {avg_rank:.2f}, Max Rank: {ranks.max().item()}",
83
+ )
84
+
85
+ logging.info(
86
+ f"Step {step}, Loss: {loss.item():.4f}, Avg Rank: {avg_rank:.2f}, Max Rank: {ranks.max().item()}"
87
+ )
88
+
89
+ if loss.item() < best_loss:
90
+ best_loss = loss.item()
91
+ best_simon_says_prompt = simon_says_prompt.detach().clone()
92
+
93
+ # If all ranks are 1, stop the optimization (perfect prediction)
94
+ if torch.all(ranks == 1):
95
+ break
96
+
97
+ return best_simon_says_prompt
98
+
99
+
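The `ranks` tensor doubles as the stopping criterion: for each target position, rank = 1 + the number of vocabulary entries scored above the correct token, so all-ones means greedy argmax reproduces the text exactly. A toy illustration of the same gather-and-count (hypothetical 4-token vocabulary):

    probs = torch.tensor([[[0.1, 0.6, 0.2, 0.1]]])  # (batch=1, T=1, vocab=4)
    target = torch.tensor([[2]])                    # correct token has prob 0.2
    rank = (probs > probs.gather(2, target.unsqueeze(-1))).sum(dim=-1) + 1
    # rank == 2: only token 1 (prob 0.6) outscores the target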
100
+ # Save the optimized tensor as a safetensors file for download
101
+
102
+
103
+ def download_tensor(tensor):
104
+ """
105
+ Save a tensor to a safetensors file for download.
106
+
107
+ Parameters:
108
+ tensor (torch.Tensor): The tensor to be saved.
109
+
110
+ Returns:
111
+ str: The file path of the saved tensor.
112
+ """
113
+ file_path = "optimized_tensor.safetensors"
114
+ save_file({"optimized_tensor": tensor}, file_path)
115
+ return file_path
116
+
117
+
118
+ def upload_tensor(file):
119
+ """
120
+ Load a tensor from an uploaded safetensors file.
121
+
122
+ Parameters:
123
+ file (bytes): The uploaded file containing the safetensors data.
124
+
125
+ Returns:
126
+ torch.Tensor: The loaded tensor.
127
+
128
+ Raises:
129
+ ValueError: If the safetensors file is invalid or the header is too large.
130
+ """
131
+ if isinstance(file, bytes):
132
+ file = BytesIO(file) # Wrap bytes in a BytesIO object
133
+
134
+ with tempfile.NamedTemporaryFile(delete=True) as temp_file:
135
+ temp_file.write(
136
+ file.read()
137
+ ) # Directly write the BytesIO content to the temporary file
138
+ temp_file.flush()
139
+
140
+ try:
141
+ tensor_data = load_file(temp_file.name)
142
+ except Exception as e:
143
+ raise ValueError(f"Failed to load safetensors file: {e}")
144
+
145
+ if "optimized_tensor" not in tensor_data:
146
+ raise ValueError(
147
+ "The safetensors file does not contain the expected 'optimized_tensor' key."
148
+ )
149
+
150
+ return tensor_data["optimized_tensor"]
151
+
152
+
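Round-trip usage of the two helpers above (a sketch; the shape matches the demo's default of 4 prompt tokens with GPT-2's 768-dim embeddings):

    prompt = torch.randn(1, 4, 768)
    path = download_tensor(prompt)          # writes optimized_tensor.safetensors
    with open(path, "rb") as f:
        restored = upload_tensor(f.read())  # raw bytes, as gr.File(type="binary") provides
    assert torch.equal(prompt, restored)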
153
+ def greedy_decode_with_ss_prompt(
154
+ ss_prompt: torch.Tensor, progress=gr.Progress()
155
+ ) -> str:
156
+ """
157
+ Perform greedy decoding conditioned on an optimized Simon Says prompt tensor.
158
+
159
+ Parameters:
160
+ ss_prompt (torch.Tensor): The optimized Simon Says prompt tensor.
161
+ progress (gr.Progress): Gradio progress tracking.
162
+
163
+ Returns:
164
+ str: The generated text.
165
+ """
166
+ generated_tokens = []
167
+ all_logits = []
168
+
169
+ progress(0, desc="Starting greedy decoding...")
170
+
171
+ with torch.no_grad():
172
+ for i in progress.tqdm(range(150), desc="Decoding..."):
173
+ if len(generated_tokens) == 0:
174
+ expanded_prompt = ss_prompt
175
+ else:
176
+ expanded_prompt = torch.cat(
177
+ [
178
+ ss_prompt,
179
+ model.transformer.wte(
180
+ torch.tensor(generated_tokens).unsqueeze(0)
181
+ ).detach(),
182
+ ],
183
+ dim=1,
184
+ )
185
+
186
+ logits = model(inputs_embeds=expanded_prompt).logits
187
+ next_token_logits = logits[0, -1, :]
188
+ next_token = next_token_logits.argmax().item()
189
+
190
+ logging.info(
191
+ f"Step {i}, Next Token: {next_token}, Logit: {next_token_logits[next_token].item()}"
192
+ )
193
+
194
+ generated_tokens.append(next_token)
195
+ all_logits.append(next_token_logits)
196
+
197
+ if next_token == tokenizer.eos_token_id:
198
+ break
199
+
200
+ generated_tokens = torch.tensor(generated_tokens)
201
+ generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
202
+
203
+ return generated_text
204
+
205
+
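One design note on the loop above: it re-encodes the whole prefix at every step, which is quadratic in the generated length. Since the model returns `past_key_values`, a cached variant (a sketch, not what demo.py does; `ss_prompt` again being the (1, k, n_embd) tensor) would feed only the newest token's embedding each step:

    with torch.no_grad():
        out = model(inputs_embeds=ss_prompt, use_cache=True)
        token = out.logits[0, -1].argmax().item()
        generated = [token]
        for _ in range(149):
            emb = model.transformer.wte(torch.tensor([[token]]))
            out = model(inputs_embeds=emb, past_key_values=out.past_key_values, use_cache=True)
            token = out.logits[0, -1].argmax().item()
            generated.append(token)
            if token == tokenizer.eos_token_id:
                break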
206
+ def process_and_generate(
207
+ input_text: str,
208
+ number_of_simon_says_tokens: int,
209
+ n_steps: int,
210
+ lr: float,
211
+ ) -> tuple[str, str]:
212
+ """
213
+ Optimize the Simon Says prompt, display the optimization process, and generate text based on the input text.
214
+
215
+ Parameters:
216
+ input_text (str): The input text provided by the user.
217
+ number_of_simon_says_tokens (int): Number of Simon Says tokens to optimize.
218
+ n_steps (int): Number of optimization steps.
219
+ lr (float): Learning rate for the optimization process.
220
+
221
+ Returns:
222
+ tuple: The optimized Simon Says prompt and the greedy-decoded text.
223
+ """
224
+ optimized_prompt = optimize_simon_says_prompt(
225
+ input_text=input_text,
226
+ number_of_simon_says_tokens=number_of_simon_says_tokens,
227
+ n_steps=n_steps,
228
+ lr=lr,
229
+ )
230
+
231
+ # Generate text using the optimized prompt
232
+ generated_text: str = greedy_decode_with_ss_prompt(optimized_prompt)
233
+
234
+ return (
235
+ generated_text,
236
+ download_tensor(optimized_prompt),
237
+ ) # Return the optimized tensor for download
238
+
239
+
240
+ def process_with_uploaded_tensor(
241
+ input_text: str, uploaded_tensor: torch.Tensor
242
+ ) -> tuple[str, str]:
243
+ """
244
+ Generate text from an uploaded optimized tensor (the input text is ignored in this mode).
245
+
246
+ Parameters:
247
+ input_text (str): The input text provided by the user.
248
+ uploaded_tensor (torch.Tensor): The uploaded optimized tensor.
249
+
250
+ Returns:
251
+ tuple: The generated text and None (no new tensor file is produced).
252
+ """
253
+ generated_text = greedy_decode_with_ss_prompt(uploaded_tensor)
254
+ return generated_text, None
255
+
256
+
257
+ theme = gr.themes.Soft(
258
+ primary_hue="fuchsia",
259
+ secondary_hue="cyan",
260
+ neutral_hue="gray",
261
+ radius_size="none",
262
+ font=[
263
+ gr.themes.GoogleFont("IBM Plex Sans"),
264
+ "ui-sans-serif",
265
+ "system-ui",
266
+ "sans-serif",
267
+ ],
268
+ font_mono=[
269
+ gr.themes.GoogleFont("IBM Plex Mono"),
270
+ "ui-monospace",
271
+ "Consolas",
272
+ "monospace",
273
+ ],
274
+ )
275
+
276
+ # Gradio interface with configurable optimization parameters
277
+ demo = gr.Interface(
278
+ theme=theme,
279
+ title="Simon Says Prompt Optimization and Text Generation",
280
+ fn=lambda input_text, number_of_simon_says_tokens, n_steps, lr, uploaded_file: (
281
+ process_with_uploaded_tensor(input_text, upload_tensor(uploaded_file))
282
+ if uploaded_file
283
+ else process_and_generate(
284
+ input_text, number_of_simon_says_tokens, n_steps, lr
285
+ )
286
+ ),
287
+ inputs=[
288
+ gr.Textbox(
289
+ lines=5,
290
+ placeholder="Enter your text here...",
291
+ label="Input Text",
292
+ value="Hello world! I'm Aldan, happy to be here.",
293
+ ),
294
+ gr.Slider(
295
+ minimum=1, maximum=10, step=1, value=4, label="Number of Simon Says Tokens"
296
+ ),
297
+ gr.Slider(
298
+ minimum=100,
299
+ maximum=10000,
300
+ step=100,
301
+ value=5000,
302
+ label="Number of Optimization Steps",
303
+ ),
304
+ gr.Slider(
305
+ minimum=1e-5, maximum=1e-1, step=1e-5, value=1e-2, label="Learning Rate"
306
+ ),
307
+ gr.File(label="Upload Optimized Tensor (Optional)", type="binary"),
308
+ ],
309
+ outputs=[
310
+ gr.Textbox(label="Generated Text"),
311
+ gr.File(label="Download Optimized Tensor", type="filepath"),
312
+ ],
313
+ description="This demo optimizes a Simon Says prompt based on your input text, displays the optimization process, and generates text using the optimized prompt. Optionally, you can upload a pre-optimized tensor for inference.",
314
+ )
315
+
316
+ # Launch the Gradio interface
317
+ if __name__ == "__main__":
318
+ demo.launch(debug=True, show_error=True)
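A side note on the manual decoding loop inherited from the notebook (whose comment says .generate() can't be used with inputs_embeds): recent transformers releases, including the 4.51.2 pinned in requirements.txt, do accept inputs_embeds in generate() for decoder-only models, so a shorter route should work (a sketch; when no input_ids are given, the returned ids contain only the newly generated tokens):

    from safetensors.torch import load_file

    prompt = load_file("optimized_tensor.safetensors")["optimized_tensor"]
    out_ids = model.generate(inputs_embeds=prompt, max_new_tokens=150, do_sample=False)
    print(tokenizer.decode(out_ids[0], skip_special_tokens=True))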
main.ipynb CHANGED
@@ -2,25 +2,10 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 1,
6
  "id": "d1daab37",
7
  "metadata": {},
8
- "outputs": [
9
- {
10
- "data": {
11
- "application/vnd.jupyter.widget-view+json": {
12
- "model_id": "6be175a38e1843ab83831cb3a9a06296",
13
- "version_major": 2,
14
- "version_minor": 0
15
- },
16
- "text/plain": [
17
- "Textarea(value=\"Solomonoff's theory of inductive inference proposes that all problems of logical induction can…"
18
- ]
19
- },
20
- "metadata": {},
21
- "output_type": "display_data"
22
- }
23
- ],
24
  "source": [
25
  "import ipywidgets\n",
26
  "\n",
@@ -35,683 +20,10 @@
35
  },
36
  {
37
  "cell_type": "code",
38
- "execution_count": 6,
39
  "id": "09b3c097",
40
  "metadata": {},
41
- "outputs": [
42
- {
43
- "name": "stdout",
44
- "output_type": "stream",
45
- "text": [
46
- "Optimizing: Hey this is a test text of me, Aldan, writing some random text in this text box. Will this work? Perhaps!\n"
47
- ]
48
- },
49
- {
50
- "name": "stderr",
51
- "output_type": "stream",
52
- "text": [
53
- " 0%| | 2/5000 [00:00<13:45, 6.05it/s]"
54
- ]
55
- },
56
- {
57
- "name": "stdout",
58
- "output_type": "stream",
59
- "text": [
60
- "Step 0, Loss: 5.8886637687683105, L2 norm: 0.5542519688606262, avg rank: 655.2222290039062\n"
61
- ]
62
- },
63
- {
64
- "name": "stderr",
65
- "output_type": "stream",
66
- "text": [
67
- " 0%| | 12/5000 [00:01<10:57, 7.58it/s]"
68
- ]
69
- },
70
- {
71
- "name": "stdout",
72
- "output_type": "stream",
73
- "text": [
74
- "Step 10, Loss: 5.159086227416992, L2 norm: 3.46441388130188, avg rank: 439.6666564941406\n"
75
- ]
76
- },
77
- {
78
- "name": "stderr",
79
- "output_type": "stream",
80
- "text": [
81
- " 0%| | 22/5000 [00:03<12:31, 6.62it/s]"
82
- ]
83
- },
84
- {
85
- "name": "stdout",
86
- "output_type": "stream",
87
- "text": [
88
- "Step 20, Loss: 4.893388748168945, L2 norm: 5.287822246551514, avg rank: 380.629638671875\n"
89
- ]
90
- },
91
- {
92
- "name": "stderr",
93
- "output_type": "stream",
94
- "text": [
95
- " 1%| | 32/5000 [00:04<10:37, 7.79it/s]"
96
- ]
97
- },
98
- {
99
- "name": "stdout",
100
- "output_type": "stream",
101
- "text": [
102
- "Step 30, Loss: 4.576204299926758, L2 norm: 7.122143745422363, avg rank: 313.0740661621094\n"
103
- ]
104
- },
105
- {
106
- "name": "stderr",
107
- "output_type": "stream",
108
- "text": [
109
- " 1%| | 42/5000 [00:05<10:42, 7.71it/s]"
110
- ]
111
- },
112
- {
113
- "name": "stdout",
114
- "output_type": "stream",
115
- "text": [
116
- "Step 40, Loss: 4.330563545227051, L2 norm: 8.640031814575195, avg rank: 189.8518524169922\n"
117
- ]
118
- },
119
- {
120
- "name": "stderr",
121
- "output_type": "stream",
122
- "text": [
123
- " 1%| | 52/5000 [00:07<13:39, 6.04it/s]"
124
- ]
125
- },
126
- {
127
- "name": "stdout",
128
- "output_type": "stream",
129
- "text": [
130
- "Step 50, Loss: 4.107393264770508, L2 norm: 10.037625312805176, avg rank: 130.6666717529297\n"
131
- ]
132
- },
133
- {
134
- "name": "stderr",
135
- "output_type": "stream",
136
- "text": [
137
- " 1%| | 62/5000 [00:08<11:33, 7.12it/s]"
138
- ]
139
- },
140
- {
141
- "name": "stdout",
142
- "output_type": "stream",
143
- "text": [
144
- "Step 60, Loss: 3.8990049362182617, L2 norm: 11.52373218536377, avg rank: 105.9259262084961\n"
145
- ]
146
- },
147
- {
148
- "name": "stderr",
149
- "output_type": "stream",
150
- "text": [
151
- " 1%|▏ | 72/5000 [00:10<19:46, 4.15it/s]"
152
- ]
153
- },
154
- {
155
- "name": "stdout",
156
- "output_type": "stream",
157
- "text": [
158
- "Step 70, Loss: 3.689441442489624, L2 norm: 12.885846138000488, avg rank: 96.77777862548828\n"
159
- ]
160
- },
161
- {
162
- "name": "stderr",
163
- "output_type": "stream",
164
- "text": [
165
- " 2%|▏ | 82/5000 [00:13<14:27, 5.67it/s]"
166
- ]
167
- },
168
- {
169
- "name": "stdout",
170
- "output_type": "stream",
171
- "text": [
172
- "Step 80, Loss: 3.664461374282837, L2 norm: 14.095556259155273, avg rank: 82.48148345947266\n"
173
- ]
174
- },
175
- {
176
- "name": "stderr",
177
- "output_type": "stream",
178
- "text": [
179
- " 2%|▏ | 92/5000 [00:14<12:06, 6.76it/s]"
180
- ]
181
- },
182
- {
183
- "name": "stdout",
184
- "output_type": "stream",
185
- "text": [
186
- "Step 90, Loss: 3.359441041946411, L2 norm: 15.264603614807129, avg rank: 76.62963104248047\n"
187
- ]
188
- },
189
- {
190
- "name": "stderr",
191
- "output_type": "stream",
192
- "text": [
193
- " 2%|▏ | 102/5000 [00:15<11:25, 7.14it/s]"
194
- ]
195
- },
196
- {
197
- "name": "stdout",
198
- "output_type": "stream",
199
- "text": [
200
- "Step 100, Loss: 3.2641189098358154, L2 norm: 16.456912994384766, avg rank: 59.55555725097656\n"
201
- ]
202
- },
203
- {
204
- "name": "stderr",
205
- "output_type": "stream",
206
- "text": [
207
- " 2%|▏ | 111/5000 [00:17<10:21, 7.87it/s]"
208
- ]
209
- },
210
- {
211
- "name": "stdout",
212
- "output_type": "stream",
213
- "text": [
214
- "Step 110, Loss: 3.4184069633483887, L2 norm: 17.494909286499023, avg rank: 45.22222137451172\n"
215
- ]
216
- },
217
- {
218
- "name": "stderr",
219
- "output_type": "stream",
220
- "text": [
221
- " 2%|▏ | 122/5000 [00:19<14:21, 5.66it/s]"
222
- ]
223
- },
224
- {
225
- "name": "stdout",
226
- "output_type": "stream",
227
- "text": [
228
- "Step 120, Loss: 3.057454824447632, L2 norm: 18.361045837402344, avg rank: 47.592594146728516\n"
229
- ]
230
- },
231
- {
232
- "name": "stderr",
233
- "output_type": "stream",
234
- "text": [
235
- " 3%|▎ | 132/5000 [00:20<11:23, 7.12it/s]"
236
- ]
237
- },
238
- {
239
- "name": "stdout",
240
- "output_type": "stream",
241
- "text": [
242
- "Step 130, Loss: 2.752321243286133, L2 norm: 19.076860427856445, avg rank: 37.185184478759766\n"
243
- ]
244
- },
245
- {
246
- "name": "stderr",
247
- "output_type": "stream",
248
- "text": [
249
- " 3%|▎ | 142/5000 [00:21<10:12, 7.93it/s]"
250
- ]
251
- },
252
- {
253
- "name": "stdout",
254
- "output_type": "stream",
255
- "text": [
256
- "Step 140, Loss: 2.492600679397583, L2 norm: 19.873262405395508, avg rank: 29.0\n"
257
- ]
258
- },
259
- {
260
- "name": "stderr",
261
- "output_type": "stream",
262
- "text": [
263
- " 3%|▎ | 152/5000 [00:23<12:10, 6.64it/s]"
264
- ]
265
- },
266
- {
267
- "name": "stdout",
268
- "output_type": "stream",
269
- "text": [
270
- "Step 150, Loss: 2.2693941593170166, L2 norm: 20.689470291137695, avg rank: 25.22222137451172\n"
271
- ]
272
- },
273
- {
274
- "name": "stderr",
275
- "output_type": "stream",
276
- "text": [
277
- " 3%|▎ | 162/5000 [00:24<11:26, 7.05it/s]"
278
- ]
279
- },
280
- {
281
- "name": "stdout",
282
- "output_type": "stream",
283
- "text": [
284
- "Step 160, Loss: 2.0094830989837646, L2 norm: 21.427452087402344, avg rank: 20.185184478759766\n"
285
- ]
286
- },
287
- {
288
- "name": "stderr",
289
- "output_type": "stream",
290
- "text": [
291
- " 3%|▎ | 172/5000 [00:26<11:22, 7.08it/s]"
292
- ]
293
- },
294
- {
295
- "name": "stdout",
296
- "output_type": "stream",
297
- "text": [
298
- "Step 170, Loss: 1.8882372379302979, L2 norm: 22.13208770751953, avg rank: 23.074073791503906\n"
299
- ]
300
- },
301
- {
302
- "name": "stderr",
303
- "output_type": "stream",
304
- "text": [
305
- " 4%|▎ | 182/5000 [00:27<09:56, 8.08it/s]"
306
- ]
307
- },
308
- {
309
- "name": "stdout",
310
- "output_type": "stream",
311
- "text": [
312
- "Step 180, Loss: 1.8304041624069214, L2 norm: 22.583101272583008, avg rank: 28.33333396911621\n"
313
- ]
314
- },
315
- {
316
- "name": "stderr",
317
- "output_type": "stream",
318
- "text": [
319
- " 4%|▍ | 192/5000 [00:28<10:58, 7.30it/s]"
320
- ]
321
- },
322
- {
323
- "name": "stdout",
324
- "output_type": "stream",
325
- "text": [
326
- "Step 190, Loss: 1.5684080123901367, L2 norm: 23.05557632446289, avg rank: 22.037036895751953\n"
327
- ]
328
- },
329
- {
330
- "name": "stderr",
331
- "output_type": "stream",
332
- "text": [
333
- " 4%|▍ | 202/5000 [00:30<09:58, 8.02it/s]"
334
- ]
335
- },
336
- {
337
- "name": "stdout",
338
- "output_type": "stream",
339
- "text": [
340
- "Step 200, Loss: 1.3705590963363647, L2 norm: 23.487506866455078, avg rank: 18.925926208496094\n"
341
- ]
342
- },
343
- {
344
- "name": "stderr",
345
- "output_type": "stream",
346
- "text": [
347
- " 4%|▍ | 212/5000 [00:31<09:58, 8.00it/s]"
348
- ]
349
- },
350
- {
351
- "name": "stdout",
352
- "output_type": "stream",
353
- "text": [
354
- "Step 210, Loss: 1.2061578035354614, L2 norm: 23.888507843017578, avg rank: 15.481481552124023\n"
355
- ]
356
- },
357
- {
358
- "name": "stderr",
359
- "output_type": "stream",
360
- "text": [
361
- " 4%|▍ | 222/5000 [00:32<10:36, 7.50it/s]"
362
- ]
363
- },
364
- {
365
- "name": "stdout",
366
- "output_type": "stream",
367
- "text": [
368
- "Step 220, Loss: 1.0748673677444458, L2 norm: 24.25291633605957, avg rank: 13.333333015441895\n"
369
- ]
370
- },
371
- {
372
- "name": "stderr",
373
- "output_type": "stream",
374
- "text": [
375
- " 5%|▍ | 232/5000 [00:34<11:45, 6.76it/s]"
376
- ]
377
- },
378
- {
379
- "name": "stdout",
380
- "output_type": "stream",
381
- "text": [
382
- "Step 230, Loss: 0.9695132374763489, L2 norm: 24.588470458984375, avg rank: 12.592592239379883\n"
383
- ]
384
- },
385
- {
386
- "name": "stderr",
387
- "output_type": "stream",
388
- "text": [
389
- " 5%|▍ | 242/5000 [00:35<11:47, 6.73it/s]"
390
- ]
391
- },
392
- {
393
- "name": "stdout",
394
- "output_type": "stream",
395
- "text": [
396
- "Step 240, Loss: 0.8693955540657043, L2 norm: 24.909320831298828, avg rank: 11.518518447875977\n"
397
- ]
398
- },
399
- {
400
- "name": "stderr",
401
- "output_type": "stream",
402
- "text": [
403
- " 5%|▌ | 252/5000 [00:36<09:53, 7.99it/s]"
404
- ]
405
- },
406
- {
407
- "name": "stdout",
408
- "output_type": "stream",
409
- "text": [
410
- "Step 250, Loss: 0.7695979475975037, L2 norm: 25.205760955810547, avg rank: 10.037036895751953\n"
411
- ]
412
- },
413
- {
414
- "name": "stderr",
415
- "output_type": "stream",
416
- "text": [
417
- " 5%|▌ | 262/5000 [00:38<10:06, 7.81it/s]"
418
- ]
419
- },
420
- {
421
- "name": "stdout",
422
- "output_type": "stream",
423
- "text": [
424
- "Step 260, Loss: 0.6803638935089111, L2 norm: 25.493518829345703, avg rank: 8.666666984558105\n"
425
- ]
426
- },
427
- {
428
- "name": "stderr",
429
- "output_type": "stream",
430
- "text": [
431
- " 5%|▌ | 272/5000 [00:39<11:42, 6.73it/s]"
432
- ]
433
- },
434
- {
435
- "name": "stdout",
436
- "output_type": "stream",
437
- "text": [
438
- "Step 270, Loss: 0.6074316501617432, L2 norm: 25.7545108795166, avg rank: 7.407407283782959\n"
439
- ]
440
- },
441
- {
442
- "name": "stderr",
443
- "output_type": "stream",
444
- "text": [
445
- " 6%|▌ | 282/5000 [00:40<10:20, 7.60it/s]"
446
- ]
447
- },
448
- {
449
- "name": "stdout",
450
- "output_type": "stream",
451
- "text": [
452
- "Step 280, Loss: 0.5503782629966736, L2 norm: 25.987356185913086, avg rank: 6.666666507720947\n"
453
- ]
454
- },
455
- {
456
- "name": "stderr",
457
- "output_type": "stream",
458
- "text": [
459
- " 6%|▌ | 292/5000 [00:42<10:42, 7.33it/s]"
460
- ]
461
- },
462
- {
463
- "name": "stdout",
464
- "output_type": "stream",
465
- "text": [
466
- "Step 290, Loss: 0.5049920678138733, L2 norm: 26.197214126586914, avg rank: 6.111111164093018\n"
467
- ]
468
- },
469
- {
470
- "name": "stderr",
471
- "output_type": "stream",
472
- "text": [
473
- " 6%|▌ | 302/5000 [00:43<11:42, 6.68it/s]"
474
- ]
475
- },
476
- {
477
- "name": "stdout",
478
- "output_type": "stream",
479
- "text": [
480
- "Step 300, Loss: 0.46782854199409485, L2 norm: 26.38979721069336, avg rank: 5.407407283782959\n"
481
- ]
482
- },
483
- {
484
- "name": "stderr",
485
- "output_type": "stream",
486
- "text": [
487
- " 6%|▌ | 312/5000 [00:45<10:34, 7.39it/s]"
488
- ]
489
- },
490
- {
491
- "name": "stdout",
492
- "output_type": "stream",
493
- "text": [
494
- "Step 310, Loss: 0.4365224838256836, L2 norm: 26.568010330200195, avg rank: 4.666666507720947\n"
495
- ]
496
- },
497
- {
498
- "name": "stderr",
499
- "output_type": "stream",
500
- "text": [
501
- " 6%|▋ | 322/5000 [00:46<10:26, 7.46it/s]"
502
- ]
503
- },
504
- {
505
- "name": "stdout",
506
- "output_type": "stream",
507
- "text": [
508
- "Step 320, Loss: 0.4172615706920624, L2 norm: 26.73436737060547, avg rank: 4.296296119689941\n"
509
- ]
510
- },
511
- {
512
- "name": "stderr",
513
- "output_type": "stream",
514
- "text": [
515
- " 7%|▋ | 332/5000 [00:48<11:34, 6.73it/s]"
516
- ]
517
- },
518
- {
519
- "name": "stdout",
520
- "output_type": "stream",
521
- "text": [
522
- "Step 330, Loss: 0.44890207052230835, L2 norm: 26.890453338623047, avg rank: 4.777777671813965\n"
523
- ]
524
- },
525
- {
526
- "name": "stderr",
527
- "output_type": "stream",
528
- "text": [
529
- " 7%|▋ | 342/5000 [00:49<09:44, 7.97it/s]"
530
- ]
531
- },
532
- {
533
- "name": "stdout",
534
- "output_type": "stream",
535
- "text": [
536
- "Step 340, Loss: 1.6345257759094238, L2 norm: 27.1015625, avg rank: 10.0\n"
537
- ]
538
- },
539
- {
540
- "name": "stderr",
541
- "output_type": "stream",
542
- "text": [
543
- " 7%|▋ | 352/5000 [00:50<10:12, 7.59it/s]"
544
- ]
545
- },
546
- {
547
- "name": "stdout",
548
- "output_type": "stream",
549
- "text": [
550
- "Step 350, Loss: 1.079402208328247, L2 norm: 27.411144256591797, avg rank: 6.592592716217041\n"
551
- ]
552
- },
553
- {
554
- "name": "stderr",
555
- "output_type": "stream",
556
- "text": [
557
- " 7%|▋ | 362/5000 [00:52<12:10, 6.35it/s]"
558
- ]
559
- },
560
- {
561
- "name": "stdout",
562
- "output_type": "stream",
563
- "text": [
564
- "Step 360, Loss: 0.8217418193817139, L2 norm: 27.712562561035156, avg rank: 4.703703880310059\n"
565
- ]
566
- },
567
- {
568
- "name": "stderr",
569
- "output_type": "stream",
570
- "text": [
571
- " 7%|▋ | 372/5000 [00:53<11:04, 6.97it/s]"
572
- ]
573
- },
574
- {
575
- "name": "stdout",
576
- "output_type": "stream",
577
- "text": [
578
- "Step 370, Loss: 0.6470327973365784, L2 norm: 28.024415969848633, avg rank: 1.8888888359069824\n"
579
- ]
580
- },
581
- {
582
- "name": "stderr",
583
- "output_type": "stream",
584
- "text": [
585
- " 8%|▊ | 382/5000 [00:55<10:21, 7.43it/s]"
586
- ]
587
- },
588
- {
589
- "name": "stdout",
590
- "output_type": "stream",
591
- "text": [
592
- "Step 380, Loss: 1.0997235774993896, L2 norm: 28.2260684967041, avg rank: 7.185184955596924\n"
593
- ]
594
- },
595
- {
596
- "name": "stderr",
597
- "output_type": "stream",
598
- "text": [
599
- " 8%|▊ | 392/5000 [00:56<10:58, 7.00it/s]"
600
- ]
601
- },
602
- {
603
- "name": "stdout",
604
- "output_type": "stream",
605
- "text": [
606
- "Step 390, Loss: 0.64601731300354, L2 norm: 28.461565017700195, avg rank: 2.9259259700775146\n"
607
- ]
608
- },
609
- {
610
- "name": "stderr",
611
- "output_type": "stream",
612
- "text": [
613
- " 8%|▊ | 402/5000 [00:58<09:58, 7.68it/s]"
614
- ]
615
- },
616
- {
617
- "name": "stdout",
618
- "output_type": "stream",
619
- "text": [
620
- "Step 400, Loss: 0.482523649930954, L2 norm: 28.71732521057129, avg rank: 1.5185185670852661\n"
621
- ]
622
- },
623
- {
624
- "name": "stderr",
625
- "output_type": "stream",
626
- "text": [
627
- " 8%|▊ | 412/5000 [00:59<09:54, 7.71it/s]"
628
- ]
629
- },
630
- {
631
- "name": "stdout",
632
- "output_type": "stream",
633
- "text": [
634
- "Step 410, Loss: 0.38174617290496826, L2 norm: 28.943510055541992, avg rank: 1.2592592239379883\n"
635
- ]
636
- },
637
- {
638
- "name": "stderr",
639
- "output_type": "stream",
640
- "text": [
641
- " 8%|▊ | 422/5000 [01:01<12:41, 6.01it/s]"
642
- ]
643
- },
644
- {
645
- "name": "stdout",
646
- "output_type": "stream",
647
- "text": [
648
- "Step 420, Loss: 0.29940587282180786, L2 norm: 29.132537841796875, avg rank: 1.1111111640930176\n"
649
- ]
650
- },
651
- {
652
- "name": "stderr",
653
- "output_type": "stream",
654
- "text": [
655
- " 9%|▊ | 432/5000 [01:02<10:03, 7.57it/s]"
656
- ]
657
- },
658
- {
659
- "name": "stdout",
660
- "output_type": "stream",
661
- "text": [
662
- "Step 430, Loss: 0.23488281667232513, L2 norm: 29.293710708618164, avg rank: 1.0740740299224854\n"
663
- ]
664
- },
665
- {
666
- "name": "stderr",
667
- "output_type": "stream",
668
- "text": [
669
- " 9%|▉ | 442/5000 [01:03<09:35, 7.91it/s]"
670
- ]
671
- },
672
- {
673
- "name": "stdout",
674
- "output_type": "stream",
675
- "text": [
676
- "Step 440, Loss: 0.19830304384231567, L2 norm: 29.436723709106445, avg rank: 1.0370370149612427\n"
677
- ]
678
- },
679
- {
680
- "name": "stderr",
681
- "output_type": "stream",
682
- "text": [
683
- " 9%|▉ | 452/5000 [01:05<10:41, 7.09it/s]"
684
- ]
685
- },
686
- {
687
- "name": "stdout",
688
- "output_type": "stream",
689
- "text": [
690
- "Step 450, Loss: 0.17161428928375244, L2 norm: 29.54309844970703, avg rank: 1.0370370149612427\n"
691
- ]
692
- },
693
- {
694
- "name": "stderr",
695
- "output_type": "stream",
696
- "text": [
697
- " 9%|▉ | 454/5000 [01:05<10:54, 6.94it/s]"
698
- ]
699
- },
700
- {
701
- "name": "stdout",
702
- "output_type": "stream",
703
- "text": [
704
- "Perfect ranks achieved at step 454, stopping optimization.\n"
705
- ]
706
- },
707
- {
708
- "name": "stderr",
709
- "output_type": "stream",
710
- "text": [
711
- "\n"
712
- ]
713
- }
714
- ],
715
  "source": [
716
  "import transformers\n",
717
  "import torch\n",
@@ -738,30 +50,28 @@
738
  "\n",
739
  "embeddings = model.transformer.wte(tokens[\"input_ids\"]).detach()\n",
740
  "\n",
741
- "# We'll use a soft prompt - a special token and its hidden states for the first token that we'll use to condition the model. We'll optimize the hidden states of this token to maximize the likelihood of the text that follows it.\n",
742
- "# Generate a soft prompt by creating random hidden states for the first token\n",
743
  "# We'll optimize these hidden states to maximize the likelihood of the text that follows it\n",
744
  "# past_key_values (Tuple[Tuple[torch.Tensor]] of length config.n_layers) — Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see past_key_values output below). Can be used to speed up sequential decoding. The input_ids which have their past given to this model should not be passed as input_ids as they have already been computed.\n",
745
  "# Shape: past_key_values (Tuple[Tuple[torch.Tensor]] of length config.n_layers)\n",
746
  "# with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head))\n",
747
  "# The precision of the hidden states should be fp16\n",
748
- "ss_prompt = (\n",
749
  " # One tensor of shape (1, 1, embed_size_per_head)\n",
750
  " torch.randn(\n",
751
  " 1,\n",
752
  " 4,\n",
753
  " model.config.n_embd,\n",
754
  " requires_grad=True,\n",
755
- " # dtype=torch.float16,\n",
756
  " )\n",
757
- " # for _ in range(model.config.n_layer)\n",
758
  ")\n",
759
  "\n",
760
- "# Copy the soft prompts since we'll optimize them and they'll change\n",
761
- "original_ss_prompt = tuple(key.detach().clone() for key in ss_prompt)\n",
762
  "\n",
763
  "# Define the optimizer\n",
764
- "optimizer = torch.optim.Adam([ss_prompt], lr=1e-2)\n",
765
  "\n",
766
  "# Define the loss function\n",
767
  "loss_fn = torch.nn.CrossEntropyLoss()\n",
@@ -775,8 +85,8 @@
775
  "# Freeze the model parameters\n",
776
  "model.eval()\n",
777
  "\n",
778
- "# Check that the soft prompt is optimizable (requires_grad=True)\n",
779
- "for key in ss_prompt:\n",
780
  " assert key.requires_grad\n",
781
  "\n",
782
  "# Disable gradient computation for the model\n",
@@ -784,20 +94,20 @@
784
  " param.requires_grad = False\n",
785
  "\n",
786
  "best_loss = float(\"inf\")\n",
787
- "best_ss_prompt = None\n",
788
  "\n",
789
- "# Optimize the soft prompt\n",
790
  "for step in tqdm.tqdm(range(n_steps)):\n",
791
  " # Zero the gradients\n",
792
  " optimizer.zero_grad()\n",
793
  "\n",
794
- " # Add the optimizable SS prompt to the embeddings\n",
795
- " expanded_prompt = torch.cat([ss_prompt, embeddings], dim=1)\n",
796
  "\n",
797
  " # Generate the logits for the text\n",
798
  " logits: torch.Tensor = model(inputs_embeds=expanded_prompt).logits\n",
799
  "\n",
800
- " probs = torch.softmax(logits[:, ss_prompt.size(-2) - 1 :-1], dim=-1)\n",
801
  "\n",
802
  " # Compute the ranks of the input IDs, i.e. how many tokens would have been more likely than the correct one (the label, the input IDs)\n",
803
  " \n",
@@ -806,21 +116,21 @@
806
  "\n",
807
  " # Compute the loss\n",
808
  " loss = loss_fn(\n",
809
- " logits[:, ss_prompt.size(-2) - 1 :-1].reshape(-1, logits.size(-1)),\n",
810
  " tokens[\"input_ids\"].reshape(-1),\n",
811
  " )\n",
812
  "\n",
813
  " # Backpropagate the gradients\n",
814
  " loss.backward()\n",
815
  "\n",
816
- " # Optimize the soft prompt\n",
817
  " optimizer.step()\n",
818
  "\n",
819
  " if step % 10 == 0:\n",
820
- " # Get the L2 norm of the difference between the original and optimized soft prompts\n",
821
  " l2_norm = sum(\n",
822
  " torch.norm(optimized - original, p=2)\n",
823
- " for optimized, original in zip(ss_prompt, original_ss_prompt)\n",
824
  " )\n",
825
  " print(\n",
826
  " f\"Step {step}, Loss: {loss.item()}, L2 norm: {l2_norm.item()}, avg rank: {ranks.float().mean().item()}\"\n",
@@ -829,7 +139,7 @@
829
  " # Early stopping with patience\n",
830
  " if loss.item() < best_loss and loss.item() > epsilon:\n",
831
  " best_loss = loss.item()\n",
832
- " best_ss_prompt = ss_prompt.detach().clone()\n",
833
  " patience_counter = 0\n",
834
  " else:\n",
835
  " patience_counter += 1\n",
@@ -846,187 +156,21 @@
846
  },
847
  {
848
  "cell_type": "code",
849
- "execution_count": 7,
850
  "id": "cc9a6a2f",
851
  "metadata": {},
852
  "outputs": [],
853
  "source": [
854
- "# Save the best soft prompt\n",
855
- "torch.save(best_ss_prompt, \"best_ss_prompt.pt\")"
856
  ]
857
  },
858
  {
859
  "cell_type": "code",
860
- "execution_count": 8,
861
  "id": "3186747d",
862
  "metadata": {},
863
- "outputs": [
864
- {
865
- "name": "stderr",
866
- "output_type": "stream",
867
- "text": [
868
- "100%|██████████| 150/150 [00:10<00:00, 14.83it/s]\n"
869
- ]
870
- },
871
- {
872
- "name": "stdout",
873
- "output_type": "stream",
874
- "text": [
875
- "Reference: Hey this is a test text of me, Aldan, writing some random text in this text box. Will this work? Perhaps!\n",
876
- "Generated: Hey this is a test text of me, Aldan, writing some random text in this text box. Will this work? Perhaps! I'll have to test this out. Maybe. I'll have to test this out some more. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe\n",
877
- "'Hey':\tRank 0.00, probability: 71.02%, Reference: 'Hey'\n",
878
- "' this':\tRank 0.00, probability: 87.04%, Reference: ' this'\n",
879
- "' is':\tRank 0.00, probability: 99.65%, Reference: ' is'\n",
880
- "' a':\tRank 0.00, probability: 96.23%, Reference: ' a'\n",
881
- "' test':\tRank 0.00, probability: 90.31%, Reference: ' test'\n",
882
- "' text':\tRank 0.00, probability: 95.40%, Reference: ' text'\n",
883
- "' of':\tRank 0.00, probability: 62.63%, Reference: ' of'\n",
884
- "' me':\tRank 0.00, probability: 93.07%, Reference: ' me'\n",
885
- "',':\tRank 0.00, probability: 95.56%, Reference: ','\n",
886
- "' Ald':\tRank 0.00, probability: 71.15%, Reference: ' Ald'\n",
887
- "'an':\tRank 0.00, probability: 84.45%, Reference: 'an'\n",
888
- "',':\tRank 0.00, probability: 97.80%, Reference: ','\n",
889
- "' writing':\tRank 0.00, probability: 95.00%, Reference: ' writing'\n",
890
- "' some':\tRank 0.00, probability: 88.69%, Reference: ' some'\n",
891
- "' random':\tRank 0.00, probability: 97.07%, Reference: ' random'\n",
892
- "' text':\tRank 0.00, probability: 89.14%, Reference: ' text'\n",
893
- "' in':\tRank 0.00, probability: 90.19%, Reference: ' in'\n",
894
- "' this':\tRank 0.00, probability: 98.12%, Reference: ' this'\n",
895
- "' text':\tRank 0.00, probability: 96.47%, Reference: ' text'\n",
896
- "' box':\tRank 0.00, probability: 95.55%, Reference: ' box'\n",
897
- "'.':\tRank 0.00, probability: 97.38%, Reference: '.'\n",
898
- "' Will':\tRank 0.00, probability: 90.65%, Reference: ' Will'\n",
899
- "' this':\tRank 0.00, probability: 93.50%, Reference: ' this'\n",
900
- "' work':\tRank 0.00, probability: 49.78%, Reference: ' work'\n",
901
- "'?':\tRank 0.00, probability: 96.33%, Reference: '?'\n",
902
- "' Perhaps':\tRank 0.00, probability: 76.97%, Reference: ' Perhaps'\n",
903
- "'!':\tRank 0.00, probability: 42.45%, Reference: '!'\n",
904
- "' I':\tRank 0.00, probability: 11.67%, Reference: 'N/A'\n",
905
- "''ll':\tRank 0.00, probability: 10.19%, Reference: 'N/A'\n",
906
- "' have':\tRank 0.00, probability: 53.08%, Reference: 'N/A'\n",
907
- "' to':\tRank 0.00, probability: 77.45%, Reference: 'N/A'\n",
908
- "' test':\tRank 0.00, probability: 8.06%, Reference: 'N/A'\n",
909
- "' this':\tRank 0.00, probability: 40.56%, Reference: 'N/A'\n",
910
- "' out':\tRank 0.00, probability: 51.98%, Reference: 'N/A'\n",
911
- "'.':\tRank 0.00, probability: 26.31%, Reference: 'N/A'\n",
912
- "' Maybe':\tRank 0.00, probability: 27.11%, Reference: 'N/A'\n",
913
- "'.':\tRank 0.00, probability: 22.16%, Reference: 'N/A'\n",
914
- "' I':\tRank 0.00, probability: 20.00%, Reference: 'N/A'\n",
915
- "''ll':\tRank 0.00, probability: 21.35%, Reference: 'N/A'\n",
916
- "' have':\tRank 0.00, probability: 59.45%, Reference: 'N/A'\n",
917
- "' to':\tRank 0.00, probability: 74.45%, Reference: 'N/A'\n",
918
- "' test':\tRank 0.00, probability: 17.91%, Reference: 'N/A'\n",
919
- "' this':\tRank 0.00, probability: 70.72%, Reference: 'N/A'\n",
920
- "' out':\tRank 0.00, probability: 87.33%, Reference: 'N/A'\n",
921
- "' some':\tRank 0.00, probability: 28.23%, Reference: 'N/A'\n",
922
- "' more':\tRank 0.00, probability: 26.20%, Reference: 'N/A'\n",
923
- "'.':\tRank 0.00, probability: 22.73%, Reference: 'N/A'\n",
924
- "' Maybe':\tRank 0.00, probability: 36.98%, Reference: 'N/A'\n",
925
- "'.':\tRank 0.00, probability: 54.91%, Reference: 'N/A'\n",
926
- "' Maybe':\tRank 0.00, probability: 64.39%, Reference: 'N/A'\n",
927
- "'.':\tRank 0.00, probability: 65.85%, Reference: 'N/A'\n",
928
- "' Maybe':\tRank 0.00, probability: 65.80%, Reference: 'N/A'\n",
929
- "'.':\tRank 0.00, probability: 51.26%, Reference: 'N/A'\n",
930
- "' Maybe':\tRank 0.00, probability: 60.87%, Reference: 'N/A'\n",
931
- "'.':\tRank 0.00, probability: 35.56%, Reference: 'N/A'\n",
932
- "' Maybe':\tRank 0.00, probability: 48.26%, Reference: 'N/A'\n",
933
- "'.':\tRank 0.00, probability: 28.23%, Reference: 'N/A'\n",
934
- "' Maybe':\tRank 0.00, probability: 43.46%, Reference: 'N/A'\n",
935
- "'.':\tRank 0.00, probability: 23.53%, Reference: 'N/A'\n",
936
- "' Maybe':\tRank 0.00, probability: 42.47%, Reference: 'N/A'\n",
937
- "'.':\tRank 0.00, probability: 21.56%, Reference: 'N/A'\n",
938
- "' Maybe':\tRank 0.00, probability: 44.99%, Reference: 'N/A'\n",
939
- "'.':\tRank 0.00, probability: 22.61%, Reference: 'N/A'\n",
940
- "' Maybe':\tRank 0.00, probability: 48.24%, Reference: 'N/A'\n",
941
- "'.':\tRank 0.00, probability: 19.47%, Reference: 'N/A'\n",
942
- "' Maybe':\tRank 0.00, probability: 51.23%, Reference: 'N/A'\n",
943
- "'.':\tRank 0.00, probability: 19.13%, Reference: 'N/A'\n",
944
- "' Maybe':\tRank 0.00, probability: 54.27%, Reference: 'N/A'\n",
945
- "'.':\tRank 0.00, probability: 20.28%, Reference: 'N/A'\n",
946
- "' Maybe':\tRank 0.00, probability: 56.09%, Reference: 'N/A'\n",
947
- "'.':\tRank 0.00, probability: 23.08%, Reference: 'N/A'\n",
948
- "' Maybe':\tRank 0.00, probability: 60.23%, Reference: 'N/A'\n",
949
- "'.':\tRank 0.00, probability: 25.86%, Reference: 'N/A'\n",
950
- "' Maybe':\tRank 0.00, probability: 63.49%, Reference: 'N/A'\n",
951
- "'.':\tRank 0.00, probability: 28.66%, Reference: 'N/A'\n",
952
- "' Maybe':\tRank 0.00, probability: 65.52%, Reference: 'N/A'\n",
953
- "'.':\tRank 0.00, probability: 29.81%, Reference: 'N/A'\n",
954
- "' Maybe':\tRank 0.00, probability: 67.58%, Reference: 'N/A'\n",
955
- "'.':\tRank 0.00, probability: 32.28%, Reference: 'N/A'\n",
956
- "' Maybe':\tRank 0.00, probability: 69.40%, Reference: 'N/A'\n",
957
- "'.':\tRank 0.00, probability: 33.50%, Reference: 'N/A'\n",
958
- "' Maybe':\tRank 0.00, probability: 70.15%, Reference: 'N/A'\n",
959
- "'.':\tRank 0.00, probability: 34.01%, Reference: 'N/A'\n",
960
- "' Maybe':\tRank 0.00, probability: 71.20%, Reference: 'N/A'\n",
961
- "'.':\tRank 0.00, probability: 36.83%, Reference: 'N/A'\n",
962
- "' Maybe':\tRank 0.00, probability: 71.29%, Reference: 'N/A'\n",
963
- "'.':\tRank 0.00, probability: 38.54%, Reference: 'N/A'\n",
964
- "' Maybe':\tRank 0.00, probability: 72.09%, Reference: 'N/A'\n",
965
- "'.':\tRank 0.00, probability: 40.10%, Reference: 'N/A'\n",
966
- "' Maybe':\tRank 0.00, probability: 71.66%, Reference: 'N/A'\n",
967
- "'.':\tRank 0.00, probability: 41.78%, Reference: 'N/A'\n",
968
- "' Maybe':\tRank 0.00, probability: 72.04%, Reference: 'N/A'\n",
969
- "'.':\tRank 0.00, probability: 42.55%, Reference: 'N/A'\n",
970
- "' Maybe':\tRank 0.00, probability: 71.18%, Reference: 'N/A'\n",
971
- "'.':\tRank 0.00, probability: 44.33%, Reference: 'N/A'\n",
972
- "' Maybe':\tRank 0.00, probability: 70.62%, Reference: 'N/A'\n",
973
- "'.':\tRank 0.00, probability: 44.84%, Reference: 'N/A'\n",
974
- "' Maybe':\tRank 0.00, probability: 70.70%, Reference: 'N/A'\n",
975
- "'.':\tRank 0.00, probability: 44.89%, Reference: 'N/A'\n",
976
- "' Maybe':\tRank 0.00, probability: 71.80%, Reference: 'N/A'\n",
977
- "'.':\tRank 0.00, probability: 46.26%, Reference: 'N/A'\n",
978
- "' Maybe':\tRank 0.00, probability: 71.88%, Reference: 'N/A'\n",
979
- "'.':\tRank 0.00, probability: 46.11%, Reference: 'N/A'\n",
980
- "' Maybe':\tRank 0.00, probability: 72.24%, Reference: 'N/A'\n",
981
- "'.':\tRank 0.00, probability: 44.98%, Reference: 'N/A'\n",
982
- "' Maybe':\tRank 0.00, probability: 73.18%, Reference: 'N/A'\n",
983
- "'.':\tRank 0.00, probability: 45.08%, Reference: 'N/A'\n",
984
- "' Maybe':\tRank 0.00, probability: 75.74%, Reference: 'N/A'\n",
985
- "'.':\tRank 0.00, probability: 45.33%, Reference: 'N/A'\n",
986
- "' Maybe':\tRank 0.00, probability: 76.68%, Reference: 'N/A'\n",
987
- "'.':\tRank 0.00, probability: 46.90%, Reference: 'N/A'\n",
988
- "' Maybe':\tRank 0.00, probability: 78.19%, Reference: 'N/A'\n",
989
- "'.':\tRank 0.00, probability: 46.31%, Reference: 'N/A'\n",
990
- "' Maybe':\tRank 0.00, probability: 78.19%, Reference: 'N/A'\n",
991
- "'.':\tRank 0.00, probability: 48.19%, Reference: 'N/A'\n",
992
- "' Maybe':\tRank 0.00, probability: 79.38%, Reference: 'N/A'\n",
993
- "'.':\tRank 0.00, probability: 48.56%, Reference: 'N/A'\n",
994
- "' Maybe':\tRank 0.00, probability: 80.29%, Reference: 'N/A'\n",
995
- "'.':\tRank 0.00, probability: 49.49%, Reference: 'N/A'\n",
996
- "' Maybe':\tRank 0.00, probability: 79.51%, Reference: 'N/A'\n",
997
- "'.':\tRank 0.00, probability: 51.13%, Reference: 'N/A'\n",
998
- "' Maybe':\tRank 0.00, probability: 81.70%, Reference: 'N/A'\n",
999
- "'.':\tRank 0.00, probability: 52.29%, Reference: 'N/A'\n",
1000
- "' Maybe':\tRank 0.00, probability: 80.87%, Reference: 'N/A'\n",
1001
- "'.':\tRank 0.00, probability: 52.84%, Reference: 'N/A'\n",
1002
- "' Maybe':\tRank 0.00, probability: 81.61%, Reference: 'N/A'\n",
1003
- "'.':\tRank 0.00, probability: 54.15%, Reference: 'N/A'\n",
1004
- "' Maybe':\tRank 0.00, probability: 81.67%, Reference: 'N/A'\n",
1005
- "'.':\tRank 0.00, probability: 55.17%, Reference: 'N/A'\n",
1006
- "' Maybe':\tRank 0.00, probability: 82.23%, Reference: 'N/A'\n",
1007
- "'.':\tRank 0.00, probability: 55.47%, Reference: 'N/A'\n",
1008
- "' Maybe':\tRank 0.00, probability: 82.99%, Reference: 'N/A'\n",
1009
- "'.':\tRank 0.00, probability: 57.04%, Reference: 'N/A'\n",
1010
- "' Maybe':\tRank 0.00, probability: 83.32%, Reference: 'N/A'\n",
1011
- "'.':\tRank 0.00, probability: 58.06%, Reference: 'N/A'\n",
1012
- "' Maybe':\tRank 0.00, probability: 84.47%, Reference: 'N/A'\n",
1013
- "'.':\tRank 0.00, probability: 59.85%, Reference: 'N/A'\n",
1014
- "' Maybe':\tRank 0.00, probability: 83.82%, Reference: 'N/A'\n",
1015
- "'.':\tRank 0.00, probability: 58.62%, Reference: 'N/A'\n",
1016
- "' Maybe':\tRank 0.00, probability: 85.14%, Reference: 'N/A'\n",
1017
- "'.':\tRank 0.00, probability: 60.26%, Reference: 'N/A'\n",
1018
- "' Maybe':\tRank 0.00, probability: 85.67%, Reference: 'N/A'\n",
1019
- "'.':\tRank 0.00, probability: 62.56%, Reference: 'N/A'\n",
1020
- "' Maybe':\tRank 0.00, probability: 86.05%, Reference: 'N/A'\n",
1021
- "'.':\tRank 0.00, probability: 64.28%, Reference: 'N/A'\n",
1022
- "' Maybe':\tRank 0.00, probability: 86.93%, Reference: 'N/A'\n",
1023
- "'.':\tRank 0.00, probability: 67.84%, Reference: 'N/A'\n",
1024
- "' Maybe':\tRank 0.00, probability: 87.52%, Reference: 'N/A'\n",
1025
- "'.':\tRank 0.00, probability: 73.01%, Reference: 'N/A'\n",
1026
- "' Maybe':\tRank 0.00, probability: 88.17%, Reference: 'N/A'\n"
1027
- ]
1028
- }
1029
- ],
1030
  "source": [
1031
  "import os\n",
1032
  "import transformers\n",
@@ -1042,24 +186,24 @@
1042
  " \"openai-community/gpt2\"\n",
1043
  " )\n",
1044
  "\n",
1045
- "# Load the best soft prompt from file\n",
1046
- "best_ss_prompt = torch.load(\"best_ss_prompt.pt\")\n",
1047
  "\n",
1048
  "# Do a greedy decoding manually\n",
1049
- "# Pass the optimized soft prompt to the model\n",
1050
  "# We can't use .generate() since we need to pass the inputs_embeds\n",
1051
  "all_logits = []\n",
1052
  "generated_tokens = []\n",
1053
  "with torch.no_grad():\n",
1054
  " for i in tqdm.tqdm(range(150)):\n",
1055
  " # Generate the logits for the next token using what we've generated so far\n",
1056
- " # If there are no generated tokens yet, just take the soft prompt\n",
1057
  " if len(generated_tokens) == 0:\n",
1058
- " expanded_prompt = best_ss_prompt\n",
1059
  " else:\n",
1060
  " expanded_prompt = torch.cat(\n",
1061
  " [\n",
1062
- " ss_prompt,\n",
1063
  " model.transformer.wte(\n",
1064
  " torch.tensor(generated_tokens).unsqueeze(0)\n",
1065
  " ).detach(),\n",
@@ -1069,9 +213,9 @@
1069
  "\n",
1070
  " assert expanded_prompt.shape == (\n",
1071
  " 1,\n",
1072
- " best_ss_prompt.size(-2) + len(generated_tokens),\n",
1073
  " model.config.n_embd,\n",
1074
- " ), f\"Got size {expanded_prompt.shape} instead of (1, {best_ss_prompt.size(-2) + len(generated_tokens)}, {model.config.n_embd})\"\n",
1075
  "\n",
1076
  " # Generate the logits for the text\n",
1077
  " logits: torch.Tensor = model(inputs_embeds=expanded_prompt).logits\n",
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": null,
6
  "id": "d1daab37",
7
  "metadata": {},
8
+ "outputs": [],
9
  "source": [
10
  "import ipywidgets\n",
11
  "\n",
 
20
  },
21
  {
22
  "cell_type": "code",
23
+ "execution_count": null,
24
  "id": "09b3c097",
25
  "metadata": {},
26
+ "outputs": [],
27
  "source": [
28
  "import transformers\n",
29
  "import torch\n",
 
50
  "\n",
51
  "embeddings = model.transformer.wte(tokens[\"input_ids\"]).detach()\n",
52
  "\n",
53
+ "# We'll use a Simon Says prompt - a special token and its hidden states for the first token that we'll use to condition the model. We'll optimize the hidden states of this token to maximize the likelihood of the text that follows it.\n",
54
+ "# Generate a Simon Says prompt by creating random hidden states for the first token\n",
55
  "# We'll optimize these hidden states to maximize the likelihood of the text that follows it\n",
56
  "# past_key_values (Tuple[Tuple[torch.Tensor]] of length config.n_layers) — Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see past_key_values output below). Can be used to speed up sequential decoding. The input_ids which have their past given to this model should not be passed as input_ids as they have already been computed.\n",
57
  "# Shape: past_key_values (Tuple[Tuple[torch.Tensor]] of length config.n_layers)\n",
58
  "# with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head))\n",
59
  "# The precision of the hidden states should be fp16\n",
60
+ "simon_says_prompt = (\n",
61
  " # One tensor of shape (1, 1, embed_size_per_head)\n",
62
  " torch.randn(\n",
63
  " 1,\n",
64
  " 4,\n",
65
  " model.config.n_embd,\n",
66
  " requires_grad=True,\n",
 
67
  " )\n",
 
68
  ")\n",
69
  "\n",
70
+ "# Copy the Simon Says prompts since we'll optimize them and they'll change\n",
71
+ "original_simon_says_prompt = tuple(key.detach().clone() for key in simon_says_prompt)\n",
72
  "\n",
73
  "# Define the optimizer\n",
74
+ "optimizer = torch.optim.Adam([simon_says_prompt], lr=1e-2)\n",
75
  "\n",
76
  "# Define the loss function\n",
77
  "loss_fn = torch.nn.CrossEntropyLoss()\n",
 
85
  "# Freeze the model parameters\n",
86
  "model.eval()\n",
87
  "\n",
88
+ "# Check that the Simon Says prompt is optimizable (requires_grad=True)\n",
89
+ "for key in simon_says_prompt:\n",
90
  " assert key.requires_grad\n",
91
  "\n",
92
  "# Disable gradient computation for the model\n",
 
94
  " param.requires_grad = False\n",
95
  "\n",
96
  "best_loss = float(\"inf\")\n",
97
+ "best_simon_says_prompt = None\n",
98
  "\n",
99
+ "# Optimize the Simon Says prompt\n",
100
  "for step in tqdm.tqdm(range(n_steps)):\n",
101
  " # Zero the gradients\n",
102
  " optimizer.zero_grad()\n",
103
  "\n",
104
+ " # Add the optimizable Simon Says prompt to the embeddings\n",
105
+ " expanded_prompt = torch.cat([simon_says_prompt, embeddings], dim=1)\n",
106
  "\n",
107
  " # Generate the logits for the text\n",
108
  " logits: torch.Tensor = model(inputs_embeds=expanded_prompt).logits\n",
109
  "\n",
110
+ " probs = torch.softmax(logits[:, simon_says_prompt.size(-2) - 1 :-1], dim=-1)\n",
111
  "\n",
112
  " # Compute the ranks of the input IDs, i.e. how many tokens would have been more likely than the correct one (the label, the input IDs)\n",
113
  " \n",
 
116
  "\n",
117
  " # Compute the loss\n",
118
  " loss = loss_fn(\n",
119
+ " logits[:, simon_says_prompt.size(-2) - 1 :-1].reshape(-1, logits.size(-1)),\n",
120
  " tokens[\"input_ids\"].reshape(-1),\n",
121
  " )\n",
122
  "\n",
123
  " # Backpropagate the gradients\n",
124
  " loss.backward()\n",
125
  "\n",
126
+ " # Optimize the Simon Says prompt\n",
127
  " optimizer.step()\n",
128
  "\n",
129
  " if step % 10 == 0:\n",
130
+ " # Get the L2 norm of the difference between the original and optimized Simon Says prompts\n",
131
  " l2_norm = sum(\n",
132
  " torch.norm(optimized - original, p=2)\n",
133
+ " for optimized, original in zip(simon_says_prompt, original_simon_says_prompt)\n",
134
  " )\n",
135
  " print(\n",
136
  " f\"Step {step}, Loss: {loss.item()}, L2 norm: {l2_norm.item()}, avg rank: {ranks.float().mean().item()}\"\n",
 
139
  " # Early stopping with patience\n",
140
  " if loss.item() < best_loss and loss.item() > epsilon:\n",
141
  " best_loss = loss.item()\n",
142
+ " best_simon_says_prompt = simon_says_prompt.detach().clone()\n",
143
  " patience_counter = 0\n",
144
  " else:\n",
145
  " patience_counter += 1\n",
 
156
  },
157
  {
158
  "cell_type": "code",
159
+ "execution_count": null,
160
  "id": "cc9a6a2f",
161
  "metadata": {},
162
  "outputs": [],
163
  "source": [
164
+ "# Save the best Simon Says prompt\n",
165
+ "torch.save(best_simon_says_prompt, \"best_simon_says_prompt.pt\")"
166
  ]
167
  },
168
  {
169
  "cell_type": "code",
170
+ "execution_count": null,
171
  "id": "3186747d",
172
  "metadata": {},
173
+ "outputs": [],
174
  "source": [
175
  "import os\n",
176
  "import transformers\n",
 
186
  " \"openai-community/gpt2\"\n",
187
  " )\n",
188
  "\n",
189
+ "# Load the best Simon Says prompt from file\n",
190
+ "best_simon_says_prompt = torch.load(\"best_simon_says_prompt.pt\")\n",
191
  "\n",
192
  "# Do a greedy decoding manually\n",
193
+ "# Pass the optimized Simon Says prompt to the model\n",
194
  "# We can't use .generate() since we need to pass the inputs_embeds\n",
195
  "all_logits = []\n",
196
  "generated_tokens = []\n",
197
  "with torch.no_grad():\n",
198
  " for i in tqdm.tqdm(range(150)):\n",
199
  " # Generate the logits for the next token using what we've generated so far\n",
200
+ " # If there are no generated tokens yet, just take the Simon Says prompt\n",
201
  " if len(generated_tokens) == 0:\n",
202
+ " expanded_prompt = best_simon_says_prompt\n",
203
  " else:\n",
204
  " expanded_prompt = torch.cat(\n",
205
  " [\n",
206
+ " simon_says_prompt,\n",
207
  " model.transformer.wte(\n",
208
  " torch.tensor(generated_tokens).unsqueeze(0)\n",
209
  " ).detach(),\n",
 
213
  "\n",
214
  " assert expanded_prompt.shape == (\n",
215
  " 1,\n",
216
+ " best_simon_says_prompt.size(-2) + len(generated_tokens),\n",
217
  " model.config.n_embd,\n",
218
+ " ), f\"Got size {expanded_prompt.shape} instead of (1, {best_simon_says_prompt.size(-2) + len(generated_tokens)}, {model.config.n_embd})\"\n",
219
  "\n",
220
  " # Generate the logits for the text\n",
221
  " logits: torch.Tensor = model(inputs_embeds=expanded_prompt).logits\n",
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
- torch
2
- transformers
3
  tqdm
4
- ipywidgets
1
+ torch==2.6.0
2
+ transformers==4.51.2
3
+ gradio==5.26.0
4
  tqdm
5
+ ipywidgets
6
+ matplotlib
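Two closing observations. The notebook still persists the prompt with torch.save (pickle-based) while demo.py moved to safetensors, which avoids pickle's arbitrary-code-execution risk on load; converting an existing checkpoint is straightforward (a sketch):

    import torch
    from safetensors.torch import save_file

    prompt = torch.load("best_simon_says_prompt.pt")
    save_file({"optimized_tensor": prompt.detach().contiguous()}, "optimized_tensor.safetensors")

Also, the pinned gradio==5.26.0 matches the sdk_version declared in the new README front matter, so the local environment tracks what the Space runs.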