Spaces:
Sleeping
Sleeping
ACMCMC
commited on
Commit
·
f856f17
1
Parent(s):
0331d85
Gradio demo
Browse files- README.md +12 -2
- demo.py +318 -0
- main.ipynb +35 -891
- requirements.txt +5 -3
README.md
CHANGED
@@ -1,4 +1,14 @@
|
|
1 |
-
|
2 |
-
Simon Says Prompts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
This repo is a demo of how to generate an embedding that gets fed into a given LLM and is optimized to make the LLM output a specific text under a greedy decoding scheme.
|
|
|
1 |
+
---
|
2 |
+
title: "Simon Says Prompts"
|
3 |
+
emoji: "🗣️"
|
4 |
+
colorFrom: "blue"
|
5 |
+
colorTo: "pink"
|
6 |
+
sdk: "gradio"
|
7 |
+
sdk_version: "5.26.0"
|
8 |
+
app_file: "demo.py"
|
9 |
+
pinned: false
|
10 |
+
---
|
11 |
+
|
12 |
+
# Simon Says Prompts
|
13 |
|
14 |
This repo is a demo of how to generate an embedding that gets fed into a given LLM and is optimized to make the LLM output a specific text under a greedy decoding scheme.
|
demo.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import transformers
|
4 |
+
import time
|
5 |
+
from safetensors.torch import save_file, load_file
|
6 |
+
import tempfile
|
7 |
+
from io import BytesIO
|
8 |
+
import logging
|
9 |
+
|
10 |
+
# Load the tokenizer and model
|
11 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained("openai-community/gpt2")
|
12 |
+
model = transformers.AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
13 |
+
|
14 |
+
|
15 |
+
# Update the optimization function in demo.py to align with the notebook
|
16 |
+
|
17 |
+
|
18 |
+
def optimize_simon_says_prompt(
|
19 |
+
input_text: str,
|
20 |
+
number_of_simon_says_tokens: int,
|
21 |
+
n_steps: int,
|
22 |
+
lr: float,
|
23 |
+
progress=gr.Progress(track_tqdm=False), # Gradio progress tracking
|
24 |
+
) -> tuple[str, torch.Tensor]:
|
25 |
+
"""
|
26 |
+
Optimize a Simon Says prompt based on the input text and display the optimization process.
|
27 |
+
|
28 |
+
Parameters:
|
29 |
+
input_text (str): The input text provided by the user.
|
30 |
+
number_of_simon_says_tokens (int): Number of Simon Says tokens to optimize.
|
31 |
+
n_steps (int): Number of optimization steps.
|
32 |
+
lr (float): Learning rate for the optimization process.
|
33 |
+
progress (gr.Progress): Gradio progress tracking.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
The optimized Simon Says prompt
|
37 |
+
"""
|
38 |
+
# Tokenize the input text
|
39 |
+
tokens = tokenizer(
|
40 |
+
input_text,
|
41 |
+
return_tensors="pt",
|
42 |
+
padding=False,
|
43 |
+
truncation=True,
|
44 |
+
add_special_tokens=True,
|
45 |
+
)
|
46 |
+
embeddings = model.transformer.wte(tokens["input_ids"]).detach()
|
47 |
+
|
48 |
+
# Initialize a random Simon Says prompt
|
49 |
+
simon_says_prompt = torch.randn(
|
50 |
+
1, number_of_simon_says_tokens, model.config.n_embd, requires_grad=True
|
51 |
+
)
|
52 |
+
optimizer = torch.optim.Adam([simon_says_prompt], lr=lr)
|
53 |
+
loss_fn = torch.nn.CrossEntropyLoss()
|
54 |
+
|
55 |
+
best_loss: float = float("inf")
|
56 |
+
best_simon_says_prompt: torch.Tensor = None
|
57 |
+
|
58 |
+
progress(0, desc="Starting optimization...")
|
59 |
+
time.sleep(1)
|
60 |
+
|
61 |
+
for step in range(n_steps):
|
62 |
+
optimizer.zero_grad()
|
63 |
+
expanded_prompt = torch.cat([simon_says_prompt, embeddings], dim=1)
|
64 |
+
logits = model(inputs_embeds=expanded_prompt).logits
|
65 |
+
probs = torch.softmax(logits[:, simon_says_prompt.size(-2) - 1 : -1], dim=-1)
|
66 |
+
ranks = (
|
67 |
+
torch.sum(
|
68 |
+
probs > probs.gather(2, tokens["input_ids"].unsqueeze(-1)), dim=-1
|
69 |
+
)
|
70 |
+
+ 1
|
71 |
+
)
|
72 |
+
loss = loss_fn(
|
73 |
+
logits[:, simon_says_prompt.size(-2) - 1 : -1].reshape(-1, logits.size(-1)),
|
74 |
+
tokens["input_ids"].reshape(-1),
|
75 |
+
)
|
76 |
+
loss.backward()
|
77 |
+
optimizer.step()
|
78 |
+
|
79 |
+
avg_rank = ranks.float().mean().item()
|
80 |
+
progress(
|
81 |
+
step / n_steps,
|
82 |
+
desc=f"Step {step}, Loss: {loss.item():.4f}, Avg Rank: {avg_rank:.2f}, Max Rank: {ranks.max().item()}",
|
83 |
+
)
|
84 |
+
|
85 |
+
logging.info(
|
86 |
+
f"Step {step}, Loss: {loss.item():.4f}, Avg Rank: {avg_rank:.2f}, Max Rank: {ranks.max().item()}"
|
87 |
+
)
|
88 |
+
|
89 |
+
if loss.item() < best_loss:
|
90 |
+
best_loss = loss.item()
|
91 |
+
best_simon_says_prompt = simon_says_prompt.detach().clone()
|
92 |
+
|
93 |
+
# If all ranks are 1, stop the optimization (perfect prediction)
|
94 |
+
if torch.all(ranks == 1):
|
95 |
+
break
|
96 |
+
|
97 |
+
return best_simon_says_prompt
|
98 |
+
|
99 |
+
|
100 |
+
# Modify the download_tensor function to save the tensor as a safetensors file
|
101 |
+
|
102 |
+
|
103 |
+
def download_tensor(tensor):
|
104 |
+
"""
|
105 |
+
Save a tensor to a safetensors file for download.
|
106 |
+
|
107 |
+
Parameters:
|
108 |
+
tensor (torch.Tensor): The tensor to be saved.
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
str: The file path of the saved tensor.
|
112 |
+
"""
|
113 |
+
file_path = "optimized_tensor.safetensors"
|
114 |
+
save_file({"optimized_tensor": tensor}, file_path)
|
115 |
+
return file_path
|
116 |
+
|
117 |
+
|
118 |
+
def upload_tensor(file):
|
119 |
+
"""
|
120 |
+
Load a tensor from an uploaded safetensors file.
|
121 |
+
|
122 |
+
Parameters:
|
123 |
+
file (bytes): The uploaded file containing the safetensors data.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
torch.Tensor: The loaded tensor.
|
127 |
+
|
128 |
+
Raises:
|
129 |
+
ValueError: If the safetensors file is invalid or the header is too large.
|
130 |
+
"""
|
131 |
+
if isinstance(file, bytes):
|
132 |
+
file = BytesIO(file) # Wrap bytes in a BytesIO object
|
133 |
+
|
134 |
+
with tempfile.NamedTemporaryFile(delete=True) as temp_file:
|
135 |
+
temp_file.write(
|
136 |
+
file.read()
|
137 |
+
) # Directly write the BytesIO content to the temporary file
|
138 |
+
temp_file.flush()
|
139 |
+
|
140 |
+
try:
|
141 |
+
tensor_data = load_file(temp_file.name)
|
142 |
+
except Exception as e:
|
143 |
+
raise ValueError(f"Failed to load safetensors file: {e}")
|
144 |
+
|
145 |
+
if "optimized_tensor" not in tensor_data:
|
146 |
+
raise ValueError(
|
147 |
+
"The safetensors file does not contain the expected 'optimized_tensor' key."
|
148 |
+
)
|
149 |
+
|
150 |
+
return tensor_data["optimized_tensor"]
|
151 |
+
|
152 |
+
|
153 |
+
def greedy_decode_with_ss_prompt(
|
154 |
+
ss_prompt: torch.Tensor, progress=gr.Progress()
|
155 |
+
) -> str:
|
156 |
+
"""
|
157 |
+
Perform greedy decoding using an uploaded optimized tensor and input text.
|
158 |
+
|
159 |
+
Parameters:
|
160 |
+
ss_prompt (torch.Tensor): The uploaded optimized tensor.
|
161 |
+
progress (gr.Progress): Gradio progress tracking.
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
str: The generated text.
|
165 |
+
"""
|
166 |
+
generated_tokens = []
|
167 |
+
all_logits = []
|
168 |
+
|
169 |
+
progress(0, desc="Starting greedy decoding...")
|
170 |
+
|
171 |
+
with torch.no_grad():
|
172 |
+
for i in progress.tqdm(range(150), desc="Decoding..."):
|
173 |
+
if len(generated_tokens) == 0:
|
174 |
+
expanded_prompt = ss_prompt
|
175 |
+
else:
|
176 |
+
expanded_prompt = torch.cat(
|
177 |
+
[
|
178 |
+
ss_prompt,
|
179 |
+
model.transformer.wte(
|
180 |
+
torch.tensor(generated_tokens).unsqueeze(0)
|
181 |
+
).detach(),
|
182 |
+
],
|
183 |
+
dim=1,
|
184 |
+
)
|
185 |
+
|
186 |
+
logits = model(inputs_embeds=expanded_prompt).logits
|
187 |
+
next_token_logits = logits[0, -1, :]
|
188 |
+
next_token = next_token_logits.argmax().item()
|
189 |
+
|
190 |
+
logging.info(
|
191 |
+
f"Step {i}, Next Token: {next_token}, Logit: {next_token_logits[next_token].item()}"
|
192 |
+
)
|
193 |
+
|
194 |
+
generated_tokens.append(next_token)
|
195 |
+
all_logits.append(next_token_logits)
|
196 |
+
|
197 |
+
if next_token == tokenizer.eos_token_id:
|
198 |
+
break
|
199 |
+
|
200 |
+
generated_tokens = torch.tensor(generated_tokens)
|
201 |
+
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
202 |
+
|
203 |
+
return generated_text
|
204 |
+
|
205 |
+
|
206 |
+
def process_and_generate(
|
207 |
+
input_text: str,
|
208 |
+
number_of_simon_says_tokens: int,
|
209 |
+
n_steps: int,
|
210 |
+
lr: float,
|
211 |
+
) -> tuple[str, str]:
|
212 |
+
"""
|
213 |
+
Optimize the Simon Says prompt, display the optimization process, and generate text based on the input text.
|
214 |
+
|
215 |
+
Parameters:
|
216 |
+
input_text (str): The input text provided by the user.
|
217 |
+
number_of_simon_says_tokens (int): Number of Simon Says tokens to optimize.
|
218 |
+
n_steps (int): Number of optimization steps.
|
219 |
+
lr (float): Learning rate for the optimization process.
|
220 |
+
|
221 |
+
Returns:
|
222 |
+
tuple: The optimized Simon Says prompt and the greedy-decoded text.
|
223 |
+
"""
|
224 |
+
optimized_prompt = optimize_simon_says_prompt(
|
225 |
+
input_text=input_text,
|
226 |
+
number_of_simon_says_tokens=number_of_simon_says_tokens,
|
227 |
+
n_steps=n_steps,
|
228 |
+
lr=lr,
|
229 |
+
)
|
230 |
+
|
231 |
+
# Generate text using the optimized prompt
|
232 |
+
generated_text: str = greedy_decode_with_ss_prompt(optimized_prompt)
|
233 |
+
|
234 |
+
return (
|
235 |
+
generated_text,
|
236 |
+
download_tensor(optimized_prompt),
|
237 |
+
) # Return the optimized tensor for download
|
238 |
+
|
239 |
+
|
240 |
+
def process_with_uploaded_tensor(
|
241 |
+
input_text: str, uploaded_tensor: torch.Tensor
|
242 |
+
) -> tuple[str, str]:
|
243 |
+
"""
|
244 |
+
Process the uploaded tensor and generate text based on the input text.
|
245 |
+
|
246 |
+
Parameters:
|
247 |
+
input_text (str): The input text provided by the user.
|
248 |
+
uploaded_tensor (torch.Tensor): The uploaded optimized tensor.
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
tuple: The generated text and the file path of the uploaded tensor.
|
252 |
+
"""
|
253 |
+
generated_text = greedy_decode_with_ss_prompt(uploaded_tensor)
|
254 |
+
return generated_text, None
|
255 |
+
|
256 |
+
|
257 |
+
theme = gr.themes.Soft(
|
258 |
+
primary_hue="fuchsia",
|
259 |
+
secondary_hue="cyan",
|
260 |
+
neutral_hue="gray",
|
261 |
+
radius_size="none",
|
262 |
+
font=[
|
263 |
+
gr.themes.GoogleFont("IBM Plex Sans"),
|
264 |
+
"ui-sans-serif",
|
265 |
+
"system-ui",
|
266 |
+
"sans-serif",
|
267 |
+
],
|
268 |
+
font_mono=[
|
269 |
+
gr.themes.GoogleFont("IBM Plex Mono"),
|
270 |
+
"ui-monospace",
|
271 |
+
"Consolas",
|
272 |
+
"monospace",
|
273 |
+
],
|
274 |
+
)
|
275 |
+
|
276 |
+
# Update the Gradio interface to include configurable parameters
|
277 |
+
demo = gr.Interface(
|
278 |
+
theme=theme,
|
279 |
+
title="Simon Says Prompt Optimization and Text Generation",
|
280 |
+
fn=lambda input_text, number_of_simon_says_tokens, n_steps, lr, uploaded_file: (
|
281 |
+
process_with_uploaded_tensor(input_text, upload_tensor(uploaded_file))
|
282 |
+
if uploaded_file
|
283 |
+
else process_and_generate(
|
284 |
+
input_text, number_of_simon_says_tokens, n_steps, lr
|
285 |
+
)
|
286 |
+
),
|
287 |
+
inputs=[
|
288 |
+
gr.Textbox(
|
289 |
+
lines=5,
|
290 |
+
placeholder="Enter your text here...",
|
291 |
+
label="Input Text",
|
292 |
+
value="Hello world! I'm Aldan, happy to be here.",
|
293 |
+
),
|
294 |
+
gr.Slider(
|
295 |
+
minimum=1, maximum=10, step=1, value=4, label="Number of Simon Says Tokens"
|
296 |
+
),
|
297 |
+
gr.Slider(
|
298 |
+
minimum=100,
|
299 |
+
maximum=10000,
|
300 |
+
step=100,
|
301 |
+
value=5000,
|
302 |
+
label="Number of Optimization Steps",
|
303 |
+
),
|
304 |
+
gr.Slider(
|
305 |
+
minimum=1e-5, maximum=1e-1, step=1e-5, value=1e-2, label="Learning Rate"
|
306 |
+
),
|
307 |
+
gr.File(label="Upload Optimized Tensor (Optional)", type="binary"),
|
308 |
+
],
|
309 |
+
outputs=[
|
310 |
+
gr.Textbox(label="Generated Text"),
|
311 |
+
gr.File(label="Download Optimized Tensor", type="filepath"),
|
312 |
+
],
|
313 |
+
description="This demo optimizes a Simon Says prompt based on your input text, displays the optimization process, and generates text using the optimized prompt. Optionally, you can upload a pre-optimized tensor for inference.",
|
314 |
+
)
|
315 |
+
|
316 |
+
# Ensure the Gradio interface is correctly launched
|
317 |
+
if __name__ == "__main__":
|
318 |
+
demo.launch(debug=True, show_error=True)
|
main.ipynb
CHANGED
@@ -2,25 +2,10 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
-
"execution_count":
|
6 |
"id": "d1daab37",
|
7 |
"metadata": {},
|
8 |
-
"outputs": [
|
9 |
-
{
|
10 |
-
"data": {
|
11 |
-
"application/vnd.jupyter.widget-view+json": {
|
12 |
-
"model_id": "6be175a38e1843ab83831cb3a9a06296",
|
13 |
-
"version_major": 2,
|
14 |
-
"version_minor": 0
|
15 |
-
},
|
16 |
-
"text/plain": [
|
17 |
-
"Textarea(value=\"Solomonoff's theory of inductive inference proposes that all problems of logical induction can…"
|
18 |
-
]
|
19 |
-
},
|
20 |
-
"metadata": {},
|
21 |
-
"output_type": "display_data"
|
22 |
-
}
|
23 |
-
],
|
24 |
"source": [
|
25 |
"import ipywidgets\n",
|
26 |
"\n",
|
@@ -35,683 +20,10 @@
|
|
35 |
},
|
36 |
{
|
37 |
"cell_type": "code",
|
38 |
-
"execution_count":
|
39 |
"id": "09b3c097",
|
40 |
"metadata": {},
|
41 |
-
"outputs": [
|
42 |
-
{
|
43 |
-
"name": "stdout",
|
44 |
-
"output_type": "stream",
|
45 |
-
"text": [
|
46 |
-
"Optimizing: Hey this is a test text of me, Aldan, writing some random text in this text box. Will this work? Perhaps!\n"
|
47 |
-
]
|
48 |
-
},
|
49 |
-
{
|
50 |
-
"name": "stderr",
|
51 |
-
"output_type": "stream",
|
52 |
-
"text": [
|
53 |
-
" 0%| | 2/5000 [00:00<13:45, 6.05it/s]"
|
54 |
-
]
|
55 |
-
},
|
56 |
-
{
|
57 |
-
"name": "stdout",
|
58 |
-
"output_type": "stream",
|
59 |
-
"text": [
|
60 |
-
"Step 0, Loss: 5.8886637687683105, L2 norm: 0.5542519688606262, avg rank: 655.2222290039062\n"
|
61 |
-
]
|
62 |
-
},
|
63 |
-
{
|
64 |
-
"name": "stderr",
|
65 |
-
"output_type": "stream",
|
66 |
-
"text": [
|
67 |
-
" 0%| | 12/5000 [00:01<10:57, 7.58it/s]"
|
68 |
-
]
|
69 |
-
},
|
70 |
-
{
|
71 |
-
"name": "stdout",
|
72 |
-
"output_type": "stream",
|
73 |
-
"text": [
|
74 |
-
"Step 10, Loss: 5.159086227416992, L2 norm: 3.46441388130188, avg rank: 439.6666564941406\n"
|
75 |
-
]
|
76 |
-
},
|
77 |
-
{
|
78 |
-
"name": "stderr",
|
79 |
-
"output_type": "stream",
|
80 |
-
"text": [
|
81 |
-
" 0%| | 22/5000 [00:03<12:31, 6.62it/s]"
|
82 |
-
]
|
83 |
-
},
|
84 |
-
{
|
85 |
-
"name": "stdout",
|
86 |
-
"output_type": "stream",
|
87 |
-
"text": [
|
88 |
-
"Step 20, Loss: 4.893388748168945, L2 norm: 5.287822246551514, avg rank: 380.629638671875\n"
|
89 |
-
]
|
90 |
-
},
|
91 |
-
{
|
92 |
-
"name": "stderr",
|
93 |
-
"output_type": "stream",
|
94 |
-
"text": [
|
95 |
-
" 1%| | 32/5000 [00:04<10:37, 7.79it/s]"
|
96 |
-
]
|
97 |
-
},
|
98 |
-
{
|
99 |
-
"name": "stdout",
|
100 |
-
"output_type": "stream",
|
101 |
-
"text": [
|
102 |
-
"Step 30, Loss: 4.576204299926758, L2 norm: 7.122143745422363, avg rank: 313.0740661621094\n"
|
103 |
-
]
|
104 |
-
},
|
105 |
-
{
|
106 |
-
"name": "stderr",
|
107 |
-
"output_type": "stream",
|
108 |
-
"text": [
|
109 |
-
" 1%| | 42/5000 [00:05<10:42, 7.71it/s]"
|
110 |
-
]
|
111 |
-
},
|
112 |
-
{
|
113 |
-
"name": "stdout",
|
114 |
-
"output_type": "stream",
|
115 |
-
"text": [
|
116 |
-
"Step 40, Loss: 4.330563545227051, L2 norm: 8.640031814575195, avg rank: 189.8518524169922\n"
|
117 |
-
]
|
118 |
-
},
|
119 |
-
{
|
120 |
-
"name": "stderr",
|
121 |
-
"output_type": "stream",
|
122 |
-
"text": [
|
123 |
-
" 1%| | 52/5000 [00:07<13:39, 6.04it/s]"
|
124 |
-
]
|
125 |
-
},
|
126 |
-
{
|
127 |
-
"name": "stdout",
|
128 |
-
"output_type": "stream",
|
129 |
-
"text": [
|
130 |
-
"Step 50, Loss: 4.107393264770508, L2 norm: 10.037625312805176, avg rank: 130.6666717529297\n"
|
131 |
-
]
|
132 |
-
},
|
133 |
-
{
|
134 |
-
"name": "stderr",
|
135 |
-
"output_type": "stream",
|
136 |
-
"text": [
|
137 |
-
" 1%| | 62/5000 [00:08<11:33, 7.12it/s]"
|
138 |
-
]
|
139 |
-
},
|
140 |
-
{
|
141 |
-
"name": "stdout",
|
142 |
-
"output_type": "stream",
|
143 |
-
"text": [
|
144 |
-
"Step 60, Loss: 3.8990049362182617, L2 norm: 11.52373218536377, avg rank: 105.9259262084961\n"
|
145 |
-
]
|
146 |
-
},
|
147 |
-
{
|
148 |
-
"name": "stderr",
|
149 |
-
"output_type": "stream",
|
150 |
-
"text": [
|
151 |
-
" 1%|▏ | 72/5000 [00:10<19:46, 4.15it/s]"
|
152 |
-
]
|
153 |
-
},
|
154 |
-
{
|
155 |
-
"name": "stdout",
|
156 |
-
"output_type": "stream",
|
157 |
-
"text": [
|
158 |
-
"Step 70, Loss: 3.689441442489624, L2 norm: 12.885846138000488, avg rank: 96.77777862548828\n"
|
159 |
-
]
|
160 |
-
},
|
161 |
-
{
|
162 |
-
"name": "stderr",
|
163 |
-
"output_type": "stream",
|
164 |
-
"text": [
|
165 |
-
" 2%|▏ | 82/5000 [00:13<14:27, 5.67it/s]"
|
166 |
-
]
|
167 |
-
},
|
168 |
-
{
|
169 |
-
"name": "stdout",
|
170 |
-
"output_type": "stream",
|
171 |
-
"text": [
|
172 |
-
"Step 80, Loss: 3.664461374282837, L2 norm: 14.095556259155273, avg rank: 82.48148345947266\n"
|
173 |
-
]
|
174 |
-
},
|
175 |
-
{
|
176 |
-
"name": "stderr",
|
177 |
-
"output_type": "stream",
|
178 |
-
"text": [
|
179 |
-
" 2%|▏ | 92/5000 [00:14<12:06, 6.76it/s]"
|
180 |
-
]
|
181 |
-
},
|
182 |
-
{
|
183 |
-
"name": "stdout",
|
184 |
-
"output_type": "stream",
|
185 |
-
"text": [
|
186 |
-
"Step 90, Loss: 3.359441041946411, L2 norm: 15.264603614807129, avg rank: 76.62963104248047\n"
|
187 |
-
]
|
188 |
-
},
|
189 |
-
{
|
190 |
-
"name": "stderr",
|
191 |
-
"output_type": "stream",
|
192 |
-
"text": [
|
193 |
-
" 2%|▏ | 102/5000 [00:15<11:25, 7.14it/s]"
|
194 |
-
]
|
195 |
-
},
|
196 |
-
{
|
197 |
-
"name": "stdout",
|
198 |
-
"output_type": "stream",
|
199 |
-
"text": [
|
200 |
-
"Step 100, Loss: 3.2641189098358154, L2 norm: 16.456912994384766, avg rank: 59.55555725097656\n"
|
201 |
-
]
|
202 |
-
},
|
203 |
-
{
|
204 |
-
"name": "stderr",
|
205 |
-
"output_type": "stream",
|
206 |
-
"text": [
|
207 |
-
" 2%|▏ | 111/5000 [00:17<10:21, 7.87it/s]"
|
208 |
-
]
|
209 |
-
},
|
210 |
-
{
|
211 |
-
"name": "stdout",
|
212 |
-
"output_type": "stream",
|
213 |
-
"text": [
|
214 |
-
"Step 110, Loss: 3.4184069633483887, L2 norm: 17.494909286499023, avg rank: 45.22222137451172\n"
|
215 |
-
]
|
216 |
-
},
|
217 |
-
{
|
218 |
-
"name": "stderr",
|
219 |
-
"output_type": "stream",
|
220 |
-
"text": [
|
221 |
-
" 2%|▏ | 122/5000 [00:19<14:21, 5.66it/s]"
|
222 |
-
]
|
223 |
-
},
|
224 |
-
{
|
225 |
-
"name": "stdout",
|
226 |
-
"output_type": "stream",
|
227 |
-
"text": [
|
228 |
-
"Step 120, Loss: 3.057454824447632, L2 norm: 18.361045837402344, avg rank: 47.592594146728516\n"
|
229 |
-
]
|
230 |
-
},
|
231 |
-
{
|
232 |
-
"name": "stderr",
|
233 |
-
"output_type": "stream",
|
234 |
-
"text": [
|
235 |
-
" 3%|▎ | 132/5000 [00:20<11:23, 7.12it/s]"
|
236 |
-
]
|
237 |
-
},
|
238 |
-
{
|
239 |
-
"name": "stdout",
|
240 |
-
"output_type": "stream",
|
241 |
-
"text": [
|
242 |
-
"Step 130, Loss: 2.752321243286133, L2 norm: 19.076860427856445, avg rank: 37.185184478759766\n"
|
243 |
-
]
|
244 |
-
},
|
245 |
-
{
|
246 |
-
"name": "stderr",
|
247 |
-
"output_type": "stream",
|
248 |
-
"text": [
|
249 |
-
" 3%|▎ | 142/5000 [00:21<10:12, 7.93it/s]"
|
250 |
-
]
|
251 |
-
},
|
252 |
-
{
|
253 |
-
"name": "stdout",
|
254 |
-
"output_type": "stream",
|
255 |
-
"text": [
|
256 |
-
"Step 140, Loss: 2.492600679397583, L2 norm: 19.873262405395508, avg rank: 29.0\n"
|
257 |
-
]
|
258 |
-
},
|
259 |
-
{
|
260 |
-
"name": "stderr",
|
261 |
-
"output_type": "stream",
|
262 |
-
"text": [
|
263 |
-
" 3%|▎ | 152/5000 [00:23<12:10, 6.64it/s]"
|
264 |
-
]
|
265 |
-
},
|
266 |
-
{
|
267 |
-
"name": "stdout",
|
268 |
-
"output_type": "stream",
|
269 |
-
"text": [
|
270 |
-
"Step 150, Loss: 2.2693941593170166, L2 norm: 20.689470291137695, avg rank: 25.22222137451172\n"
|
271 |
-
]
|
272 |
-
},
|
273 |
-
{
|
274 |
-
"name": "stderr",
|
275 |
-
"output_type": "stream",
|
276 |
-
"text": [
|
277 |
-
" 3%|▎ | 162/5000 [00:24<11:26, 7.05it/s]"
|
278 |
-
]
|
279 |
-
},
|
280 |
-
{
|
281 |
-
"name": "stdout",
|
282 |
-
"output_type": "stream",
|
283 |
-
"text": [
|
284 |
-
"Step 160, Loss: 2.0094830989837646, L2 norm: 21.427452087402344, avg rank: 20.185184478759766\n"
|
285 |
-
]
|
286 |
-
},
|
287 |
-
{
|
288 |
-
"name": "stderr",
|
289 |
-
"output_type": "stream",
|
290 |
-
"text": [
|
291 |
-
" 3%|▎ | 172/5000 [00:26<11:22, 7.08it/s]"
|
292 |
-
]
|
293 |
-
},
|
294 |
-
{
|
295 |
-
"name": "stdout",
|
296 |
-
"output_type": "stream",
|
297 |
-
"text": [
|
298 |
-
"Step 170, Loss: 1.8882372379302979, L2 norm: 22.13208770751953, avg rank: 23.074073791503906\n"
|
299 |
-
]
|
300 |
-
},
|
301 |
-
{
|
302 |
-
"name": "stderr",
|
303 |
-
"output_type": "stream",
|
304 |
-
"text": [
|
305 |
-
" 4%|▎ | 182/5000 [00:27<09:56, 8.08it/s]"
|
306 |
-
]
|
307 |
-
},
|
308 |
-
{
|
309 |
-
"name": "stdout",
|
310 |
-
"output_type": "stream",
|
311 |
-
"text": [
|
312 |
-
"Step 180, Loss: 1.8304041624069214, L2 norm: 22.583101272583008, avg rank: 28.33333396911621\n"
|
313 |
-
]
|
314 |
-
},
|
315 |
-
{
|
316 |
-
"name": "stderr",
|
317 |
-
"output_type": "stream",
|
318 |
-
"text": [
|
319 |
-
" 4%|▍ | 192/5000 [00:28<10:58, 7.30it/s]"
|
320 |
-
]
|
321 |
-
},
|
322 |
-
{
|
323 |
-
"name": "stdout",
|
324 |
-
"output_type": "stream",
|
325 |
-
"text": [
|
326 |
-
"Step 190, Loss: 1.5684080123901367, L2 norm: 23.05557632446289, avg rank: 22.037036895751953\n"
|
327 |
-
]
|
328 |
-
},
|
329 |
-
{
|
330 |
-
"name": "stderr",
|
331 |
-
"output_type": "stream",
|
332 |
-
"text": [
|
333 |
-
" 4%|▍ | 202/5000 [00:30<09:58, 8.02it/s]"
|
334 |
-
]
|
335 |
-
},
|
336 |
-
{
|
337 |
-
"name": "stdout",
|
338 |
-
"output_type": "stream",
|
339 |
-
"text": [
|
340 |
-
"Step 200, Loss: 1.3705590963363647, L2 norm: 23.487506866455078, avg rank: 18.925926208496094\n"
|
341 |
-
]
|
342 |
-
},
|
343 |
-
{
|
344 |
-
"name": "stderr",
|
345 |
-
"output_type": "stream",
|
346 |
-
"text": [
|
347 |
-
" 4%|▍ | 212/5000 [00:31<09:58, 8.00it/s]"
|
348 |
-
]
|
349 |
-
},
|
350 |
-
{
|
351 |
-
"name": "stdout",
|
352 |
-
"output_type": "stream",
|
353 |
-
"text": [
|
354 |
-
"Step 210, Loss: 1.2061578035354614, L2 norm: 23.888507843017578, avg rank: 15.481481552124023\n"
|
355 |
-
]
|
356 |
-
},
|
357 |
-
{
|
358 |
-
"name": "stderr",
|
359 |
-
"output_type": "stream",
|
360 |
-
"text": [
|
361 |
-
" 4%|▍ | 222/5000 [00:32<10:36, 7.50it/s]"
|
362 |
-
]
|
363 |
-
},
|
364 |
-
{
|
365 |
-
"name": "stdout",
|
366 |
-
"output_type": "stream",
|
367 |
-
"text": [
|
368 |
-
"Step 220, Loss: 1.0748673677444458, L2 norm: 24.25291633605957, avg rank: 13.333333015441895\n"
|
369 |
-
]
|
370 |
-
},
|
371 |
-
{
|
372 |
-
"name": "stderr",
|
373 |
-
"output_type": "stream",
|
374 |
-
"text": [
|
375 |
-
" 5%|▍ | 232/5000 [00:34<11:45, 6.76it/s]"
|
376 |
-
]
|
377 |
-
},
|
378 |
-
{
|
379 |
-
"name": "stdout",
|
380 |
-
"output_type": "stream",
|
381 |
-
"text": [
|
382 |
-
"Step 230, Loss: 0.9695132374763489, L2 norm: 24.588470458984375, avg rank: 12.592592239379883\n"
|
383 |
-
]
|
384 |
-
},
|
385 |
-
{
|
386 |
-
"name": "stderr",
|
387 |
-
"output_type": "stream",
|
388 |
-
"text": [
|
389 |
-
" 5%|▍ | 242/5000 [00:35<11:47, 6.73it/s]"
|
390 |
-
]
|
391 |
-
},
|
392 |
-
{
|
393 |
-
"name": "stdout",
|
394 |
-
"output_type": "stream",
|
395 |
-
"text": [
|
396 |
-
"Step 240, Loss: 0.8693955540657043, L2 norm: 24.909320831298828, avg rank: 11.518518447875977\n"
|
397 |
-
]
|
398 |
-
},
|
399 |
-
{
|
400 |
-
"name": "stderr",
|
401 |
-
"output_type": "stream",
|
402 |
-
"text": [
|
403 |
-
" 5%|▌ | 252/5000 [00:36<09:53, 7.99it/s]"
|
404 |
-
]
|
405 |
-
},
|
406 |
-
{
|
407 |
-
"name": "stdout",
|
408 |
-
"output_type": "stream",
|
409 |
-
"text": [
|
410 |
-
"Step 250, Loss: 0.7695979475975037, L2 norm: 25.205760955810547, avg rank: 10.037036895751953\n"
|
411 |
-
]
|
412 |
-
},
|
413 |
-
{
|
414 |
-
"name": "stderr",
|
415 |
-
"output_type": "stream",
|
416 |
-
"text": [
|
417 |
-
" 5%|▌ | 262/5000 [00:38<10:06, 7.81it/s]"
|
418 |
-
]
|
419 |
-
},
|
420 |
-
{
|
421 |
-
"name": "stdout",
|
422 |
-
"output_type": "stream",
|
423 |
-
"text": [
|
424 |
-
"Step 260, Loss: 0.6803638935089111, L2 norm: 25.493518829345703, avg rank: 8.666666984558105\n"
|
425 |
-
]
|
426 |
-
},
|
427 |
-
{
|
428 |
-
"name": "stderr",
|
429 |
-
"output_type": "stream",
|
430 |
-
"text": [
|
431 |
-
" 5%|▌ | 272/5000 [00:39<11:42, 6.73it/s]"
|
432 |
-
]
|
433 |
-
},
|
434 |
-
{
|
435 |
-
"name": "stdout",
|
436 |
-
"output_type": "stream",
|
437 |
-
"text": [
|
438 |
-
"Step 270, Loss: 0.6074316501617432, L2 norm: 25.7545108795166, avg rank: 7.407407283782959\n"
|
439 |
-
]
|
440 |
-
},
|
441 |
-
{
|
442 |
-
"name": "stderr",
|
443 |
-
"output_type": "stream",
|
444 |
-
"text": [
|
445 |
-
" 6%|▌ | 282/5000 [00:40<10:20, 7.60it/s]"
|
446 |
-
]
|
447 |
-
},
|
448 |
-
{
|
449 |
-
"name": "stdout",
|
450 |
-
"output_type": "stream",
|
451 |
-
"text": [
|
452 |
-
"Step 280, Loss: 0.5503782629966736, L2 norm: 25.987356185913086, avg rank: 6.666666507720947\n"
|
453 |
-
]
|
454 |
-
},
|
455 |
-
{
|
456 |
-
"name": "stderr",
|
457 |
-
"output_type": "stream",
|
458 |
-
"text": [
|
459 |
-
" 6%|▌ | 292/5000 [00:42<10:42, 7.33it/s]"
|
460 |
-
]
|
461 |
-
},
|
462 |
-
{
|
463 |
-
"name": "stdout",
|
464 |
-
"output_type": "stream",
|
465 |
-
"text": [
|
466 |
-
"Step 290, Loss: 0.5049920678138733, L2 norm: 26.197214126586914, avg rank: 6.111111164093018\n"
|
467 |
-
]
|
468 |
-
},
|
469 |
-
{
|
470 |
-
"name": "stderr",
|
471 |
-
"output_type": "stream",
|
472 |
-
"text": [
|
473 |
-
" 6%|▌ | 302/5000 [00:43<11:42, 6.68it/s]"
|
474 |
-
]
|
475 |
-
},
|
476 |
-
{
|
477 |
-
"name": "stdout",
|
478 |
-
"output_type": "stream",
|
479 |
-
"text": [
|
480 |
-
"Step 300, Loss: 0.46782854199409485, L2 norm: 26.38979721069336, avg rank: 5.407407283782959\n"
|
481 |
-
]
|
482 |
-
},
|
483 |
-
{
|
484 |
-
"name": "stderr",
|
485 |
-
"output_type": "stream",
|
486 |
-
"text": [
|
487 |
-
" 6%|▌ | 312/5000 [00:45<10:34, 7.39it/s]"
|
488 |
-
]
|
489 |
-
},
|
490 |
-
{
|
491 |
-
"name": "stdout",
|
492 |
-
"output_type": "stream",
|
493 |
-
"text": [
|
494 |
-
"Step 310, Loss: 0.4365224838256836, L2 norm: 26.568010330200195, avg rank: 4.666666507720947\n"
|
495 |
-
]
|
496 |
-
},
|
497 |
-
{
|
498 |
-
"name": "stderr",
|
499 |
-
"output_type": "stream",
|
500 |
-
"text": [
|
501 |
-
" 6%|▋ | 322/5000 [00:46<10:26, 7.46it/s]"
|
502 |
-
]
|
503 |
-
},
|
504 |
-
{
|
505 |
-
"name": "stdout",
|
506 |
-
"output_type": "stream",
|
507 |
-
"text": [
|
508 |
-
"Step 320, Loss: 0.4172615706920624, L2 norm: 26.73436737060547, avg rank: 4.296296119689941\n"
|
509 |
-
]
|
510 |
-
},
|
511 |
-
{
|
512 |
-
"name": "stderr",
|
513 |
-
"output_type": "stream",
|
514 |
-
"text": [
|
515 |
-
" 7%|▋ | 332/5000 [00:48<11:34, 6.73it/s]"
|
516 |
-
]
|
517 |
-
},
|
518 |
-
{
|
519 |
-
"name": "stdout",
|
520 |
-
"output_type": "stream",
|
521 |
-
"text": [
|
522 |
-
"Step 330, Loss: 0.44890207052230835, L2 norm: 26.890453338623047, avg rank: 4.777777671813965\n"
|
523 |
-
]
|
524 |
-
},
|
525 |
-
{
|
526 |
-
"name": "stderr",
|
527 |
-
"output_type": "stream",
|
528 |
-
"text": [
|
529 |
-
" 7%|▋ | 342/5000 [00:49<09:44, 7.97it/s]"
|
530 |
-
]
|
531 |
-
},
|
532 |
-
{
|
533 |
-
"name": "stdout",
|
534 |
-
"output_type": "stream",
|
535 |
-
"text": [
|
536 |
-
"Step 340, Loss: 1.6345257759094238, L2 norm: 27.1015625, avg rank: 10.0\n"
|
537 |
-
]
|
538 |
-
},
|
539 |
-
{
|
540 |
-
"name": "stderr",
|
541 |
-
"output_type": "stream",
|
542 |
-
"text": [
|
543 |
-
" 7%|▋ | 352/5000 [00:50<10:12, 7.59it/s]"
|
544 |
-
]
|
545 |
-
},
|
546 |
-
{
|
547 |
-
"name": "stdout",
|
548 |
-
"output_type": "stream",
|
549 |
-
"text": [
|
550 |
-
"Step 350, Loss: 1.079402208328247, L2 norm: 27.411144256591797, avg rank: 6.592592716217041\n"
|
551 |
-
]
|
552 |
-
},
|
553 |
-
{
|
554 |
-
"name": "stderr",
|
555 |
-
"output_type": "stream",
|
556 |
-
"text": [
|
557 |
-
" 7%|▋ | 362/5000 [00:52<12:10, 6.35it/s]"
|
558 |
-
]
|
559 |
-
},
|
560 |
-
{
|
561 |
-
"name": "stdout",
|
562 |
-
"output_type": "stream",
|
563 |
-
"text": [
|
564 |
-
"Step 360, Loss: 0.8217418193817139, L2 norm: 27.712562561035156, avg rank: 4.703703880310059\n"
|
565 |
-
]
|
566 |
-
},
|
567 |
-
{
|
568 |
-
"name": "stderr",
|
569 |
-
"output_type": "stream",
|
570 |
-
"text": [
|
571 |
-
" 7%|▋ | 372/5000 [00:53<11:04, 6.97it/s]"
|
572 |
-
]
|
573 |
-
},
|
574 |
-
{
|
575 |
-
"name": "stdout",
|
576 |
-
"output_type": "stream",
|
577 |
-
"text": [
|
578 |
-
"Step 370, Loss: 0.6470327973365784, L2 norm: 28.024415969848633, avg rank: 1.8888888359069824\n"
|
579 |
-
]
|
580 |
-
},
|
581 |
-
{
|
582 |
-
"name": "stderr",
|
583 |
-
"output_type": "stream",
|
584 |
-
"text": [
|
585 |
-
" 8%|▊ | 382/5000 [00:55<10:21, 7.43it/s]"
|
586 |
-
]
|
587 |
-
},
|
588 |
-
{
|
589 |
-
"name": "stdout",
|
590 |
-
"output_type": "stream",
|
591 |
-
"text": [
|
592 |
-
"Step 380, Loss: 1.0997235774993896, L2 norm: 28.2260684967041, avg rank: 7.185184955596924\n"
|
593 |
-
]
|
594 |
-
},
|
595 |
-
{
|
596 |
-
"name": "stderr",
|
597 |
-
"output_type": "stream",
|
598 |
-
"text": [
|
599 |
-
" 8%|▊ | 392/5000 [00:56<10:58, 7.00it/s]"
|
600 |
-
]
|
601 |
-
},
|
602 |
-
{
|
603 |
-
"name": "stdout",
|
604 |
-
"output_type": "stream",
|
605 |
-
"text": [
|
606 |
-
"Step 390, Loss: 0.64601731300354, L2 norm: 28.461565017700195, avg rank: 2.9259259700775146\n"
|
607 |
-
]
|
608 |
-
},
|
609 |
-
{
|
610 |
-
"name": "stderr",
|
611 |
-
"output_type": "stream",
|
612 |
-
"text": [
|
613 |
-
" 8%|▊ | 402/5000 [00:58<09:58, 7.68it/s]"
|
614 |
-
]
|
615 |
-
},
|
616 |
-
{
|
617 |
-
"name": "stdout",
|
618 |
-
"output_type": "stream",
|
619 |
-
"text": [
|
620 |
-
"Step 400, Loss: 0.482523649930954, L2 norm: 28.71732521057129, avg rank: 1.5185185670852661\n"
|
621 |
-
]
|
622 |
-
},
|
623 |
-
{
|
624 |
-
"name": "stderr",
|
625 |
-
"output_type": "stream",
|
626 |
-
"text": [
|
627 |
-
" 8%|▊ | 412/5000 [00:59<09:54, 7.71it/s]"
|
628 |
-
]
|
629 |
-
},
|
630 |
-
{
|
631 |
-
"name": "stdout",
|
632 |
-
"output_type": "stream",
|
633 |
-
"text": [
|
634 |
-
"Step 410, Loss: 0.38174617290496826, L2 norm: 28.943510055541992, avg rank: 1.2592592239379883\n"
|
635 |
-
]
|
636 |
-
},
|
637 |
-
{
|
638 |
-
"name": "stderr",
|
639 |
-
"output_type": "stream",
|
640 |
-
"text": [
|
641 |
-
" 8%|▊ | 422/5000 [01:01<12:41, 6.01it/s]"
|
642 |
-
]
|
643 |
-
},
|
644 |
-
{
|
645 |
-
"name": "stdout",
|
646 |
-
"output_type": "stream",
|
647 |
-
"text": [
|
648 |
-
"Step 420, Loss: 0.29940587282180786, L2 norm: 29.132537841796875, avg rank: 1.1111111640930176\n"
|
649 |
-
]
|
650 |
-
},
|
651 |
-
{
|
652 |
-
"name": "stderr",
|
653 |
-
"output_type": "stream",
|
654 |
-
"text": [
|
655 |
-
" 9%|▊ | 432/5000 [01:02<10:03, 7.57it/s]"
|
656 |
-
]
|
657 |
-
},
|
658 |
-
{
|
659 |
-
"name": "stdout",
|
660 |
-
"output_type": "stream",
|
661 |
-
"text": [
|
662 |
-
"Step 430, Loss: 0.23488281667232513, L2 norm: 29.293710708618164, avg rank: 1.0740740299224854\n"
|
663 |
-
]
|
664 |
-
},
|
665 |
-
{
|
666 |
-
"name": "stderr",
|
667 |
-
"output_type": "stream",
|
668 |
-
"text": [
|
669 |
-
" 9%|▉ | 442/5000 [01:03<09:35, 7.91it/s]"
|
670 |
-
]
|
671 |
-
},
|
672 |
-
{
|
673 |
-
"name": "stdout",
|
674 |
-
"output_type": "stream",
|
675 |
-
"text": [
|
676 |
-
"Step 440, Loss: 0.19830304384231567, L2 norm: 29.436723709106445, avg rank: 1.0370370149612427\n"
|
677 |
-
]
|
678 |
-
},
|
679 |
-
{
|
680 |
-
"name": "stderr",
|
681 |
-
"output_type": "stream",
|
682 |
-
"text": [
|
683 |
-
" 9%|▉ | 452/5000 [01:05<10:41, 7.09it/s]"
|
684 |
-
]
|
685 |
-
},
|
686 |
-
{
|
687 |
-
"name": "stdout",
|
688 |
-
"output_type": "stream",
|
689 |
-
"text": [
|
690 |
-
"Step 450, Loss: 0.17161428928375244, L2 norm: 29.54309844970703, avg rank: 1.0370370149612427\n"
|
691 |
-
]
|
692 |
-
},
|
693 |
-
{
|
694 |
-
"name": "stderr",
|
695 |
-
"output_type": "stream",
|
696 |
-
"text": [
|
697 |
-
" 9%|▉ | 454/5000 [01:05<10:54, 6.94it/s]"
|
698 |
-
]
|
699 |
-
},
|
700 |
-
{
|
701 |
-
"name": "stdout",
|
702 |
-
"output_type": "stream",
|
703 |
-
"text": [
|
704 |
-
"Perfect ranks achieved at step 454, stopping optimization.\n"
|
705 |
-
]
|
706 |
-
},
|
707 |
-
{
|
708 |
-
"name": "stderr",
|
709 |
-
"output_type": "stream",
|
710 |
-
"text": [
|
711 |
-
"\n"
|
712 |
-
]
|
713 |
-
}
|
714 |
-
],
|
715 |
"source": [
|
716 |
"import transformers\n",
|
717 |
"import torch\n",
|
@@ -738,30 +50,28 @@
|
|
738 |
"\n",
|
739 |
"embeddings = model.transformer.wte(tokens[\"input_ids\"]).detach()\n",
|
740 |
"\n",
|
741 |
-
"# We'll use a
|
742 |
-
"# Generate a
|
743 |
"# We'll optimize these hidden states to maximize the likelihood of the text that follows it\n",
|
744 |
"# past_key_values (Tuple[Tuple[torch.Tensor]] of length config.n_layers) — Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see past_key_values output below). Can be used to speed up sequential decoding. The input_ids which have their past given to this model should not be passed as input_ids as they have already been computed.\n",
|
745 |
"# Shape: past_key_values (Tuple[Tuple[torch.Tensor]] of length config.n_layers)\n",
|
746 |
"# with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head))\n",
|
747 |
"# The precision of the hidden states should be fp16\n",
|
748 |
-
"
|
749 |
" # One tensor of shape (1, 1, embed_size_per_head)\n",
|
750 |
" torch.randn(\n",
|
751 |
" 1,\n",
|
752 |
" 4,\n",
|
753 |
" model.config.n_embd,\n",
|
754 |
" requires_grad=True,\n",
|
755 |
-
" # dtype=torch.float16,\n",
|
756 |
" )\n",
|
757 |
-
" # for _ in range(model.config.n_layer)\n",
|
758 |
")\n",
|
759 |
"\n",
|
760 |
-
"# Copy the
|
761 |
-
"
|
762 |
"\n",
|
763 |
"# Define the optimizer\n",
|
764 |
-
"optimizer = torch.optim.Adam([
|
765 |
"\n",
|
766 |
"# Define the loss function\n",
|
767 |
"loss_fn = torch.nn.CrossEntropyLoss()\n",
|
@@ -775,8 +85,8 @@
|
|
775 |
"# Freeze the model parameters\n",
|
776 |
"model.eval()\n",
|
777 |
"\n",
|
778 |
-
"# Check that the
|
779 |
-
"for key in
|
780 |
" assert key.requires_grad\n",
|
781 |
"\n",
|
782 |
"# Disable gradient computation for the model\n",
|
@@ -784,20 +94,20 @@
|
|
784 |
" param.requires_grad = False\n",
|
785 |
"\n",
|
786 |
"best_loss = float(\"inf\")\n",
|
787 |
-
"
|
788 |
"\n",
|
789 |
-
"# Optimize the
|
790 |
"for step in tqdm.tqdm(range(n_steps)):\n",
|
791 |
" # Zero the gradients\n",
|
792 |
" optimizer.zero_grad()\n",
|
793 |
"\n",
|
794 |
-
" # Add the optimizable
|
795 |
-
" expanded_prompt = torch.cat([
|
796 |
"\n",
|
797 |
" # Generate the logits for the text\n",
|
798 |
" logits: torch.Tensor = model(inputs_embeds=expanded_prompt).logits\n",
|
799 |
"\n",
|
800 |
-
" probs = torch.softmax(logits[:,
|
801 |
"\n",
|
802 |
" # Compute the ranks of the input IDs, i.e. how many tokens would have been more likely than the correct one (the label, the input IDs)\n",
|
803 |
" \n",
|
@@ -806,21 +116,21 @@
|
|
806 |
"\n",
|
807 |
" # Compute the loss\n",
|
808 |
" loss = loss_fn(\n",
|
809 |
-
" logits[:,
|
810 |
" tokens[\"input_ids\"].reshape(-1),\n",
|
811 |
" )\n",
|
812 |
"\n",
|
813 |
" # Backpropagate the gradients\n",
|
814 |
" loss.backward()\n",
|
815 |
"\n",
|
816 |
-
" # Optimize the
|
817 |
" optimizer.step()\n",
|
818 |
"\n",
|
819 |
" if step % 10 == 0:\n",
|
820 |
-
" # Get the L2 norm of the difference between the original and optimized
|
821 |
" l2_norm = sum(\n",
|
822 |
" torch.norm(optimized - original, p=2)\n",
|
823 |
-
" for optimized, original in zip(
|
824 |
" )\n",
|
825 |
" print(\n",
|
826 |
" f\"Step {step}, Loss: {loss.item()}, L2 norm: {l2_norm.item()}, avg rank: {ranks.float().mean().item()}\"\n",
|
@@ -829,7 +139,7 @@
|
|
829 |
" # Early stopping with patience\n",
|
830 |
" if loss.item() < best_loss and loss.item() > epsilon:\n",
|
831 |
" best_loss = loss.item()\n",
|
832 |
-
"
|
833 |
" patience_counter = 0\n",
|
834 |
" else:\n",
|
835 |
" patience_counter += 1\n",
|
@@ -846,187 +156,21 @@
|
|
846 |
},
|
847 |
{
|
848 |
"cell_type": "code",
|
849 |
-
"execution_count":
|
850 |
"id": "cc9a6a2f",
|
851 |
"metadata": {},
|
852 |
"outputs": [],
|
853 |
"source": [
|
854 |
-
"# Save the best
|
855 |
-
"torch.save(
|
856 |
]
|
857 |
},
|
858 |
{
|
859 |
"cell_type": "code",
|
860 |
-
"execution_count":
|
861 |
"id": "3186747d",
|
862 |
"metadata": {},
|
863 |
-
"outputs": [
|
864 |
-
{
|
865 |
-
"name": "stderr",
|
866 |
-
"output_type": "stream",
|
867 |
-
"text": [
|
868 |
-
"100%|██████████| 150/150 [00:10<00:00, 14.83it/s]\n"
|
869 |
-
]
|
870 |
-
},
|
871 |
-
{
|
872 |
-
"name": "stdout",
|
873 |
-
"output_type": "stream",
|
874 |
-
"text": [
|
875 |
-
"Reference: Hey this is a test text of me, Aldan, writing some random text in this text box. Will this work? Perhaps!\n",
|
876 |
-
"Generated: Hey this is a test text of me, Aldan, writing some random text in this text box. Will this work? Perhaps! I'll have to test this out. Maybe. I'll have to test this out some more. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe\n",
|
877 |
-
"'Hey':\tRank 0.00, probability: 71.02%, Reference: 'Hey'\n",
|
878 |
-
"' this':\tRank 0.00, probability: 87.04%, Reference: ' this'\n",
|
879 |
-
"' is':\tRank 0.00, probability: 99.65%, Reference: ' is'\n",
|
880 |
-
"' a':\tRank 0.00, probability: 96.23%, Reference: ' a'\n",
|
881 |
-
"' test':\tRank 0.00, probability: 90.31%, Reference: ' test'\n",
|
882 |
-
"' text':\tRank 0.00, probability: 95.40%, Reference: ' text'\n",
|
883 |
-
"' of':\tRank 0.00, probability: 62.63%, Reference: ' of'\n",
|
884 |
-
"' me':\tRank 0.00, probability: 93.07%, Reference: ' me'\n",
|
885 |
-
"',':\tRank 0.00, probability: 95.56%, Reference: ','\n",
|
886 |
-
"' Ald':\tRank 0.00, probability: 71.15%, Reference: ' Ald'\n",
|
887 |
-
"'an':\tRank 0.00, probability: 84.45%, Reference: 'an'\n",
|
888 |
-
"',':\tRank 0.00, probability: 97.80%, Reference: ','\n",
|
889 |
-
"' writing':\tRank 0.00, probability: 95.00%, Reference: ' writing'\n",
|
890 |
-
"' some':\tRank 0.00, probability: 88.69%, Reference: ' some'\n",
|
891 |
-
"' random':\tRank 0.00, probability: 97.07%, Reference: ' random'\n",
|
892 |
-
"' text':\tRank 0.00, probability: 89.14%, Reference: ' text'\n",
|
893 |
-
"' in':\tRank 0.00, probability: 90.19%, Reference: ' in'\n",
|
894 |
-
"' this':\tRank 0.00, probability: 98.12%, Reference: ' this'\n",
|
895 |
-
"' text':\tRank 0.00, probability: 96.47%, Reference: ' text'\n",
|
896 |
-
"' box':\tRank 0.00, probability: 95.55%, Reference: ' box'\n",
|
897 |
-
"'.':\tRank 0.00, probability: 97.38%, Reference: '.'\n",
|
898 |
-
"' Will':\tRank 0.00, probability: 90.65%, Reference: ' Will'\n",
|
899 |
-
"' this':\tRank 0.00, probability: 93.50%, Reference: ' this'\n",
|
900 |
-
"' work':\tRank 0.00, probability: 49.78%, Reference: ' work'\n",
|
901 |
-
"'?':\tRank 0.00, probability: 96.33%, Reference: '?'\n",
|
902 |
-
"' Perhaps':\tRank 0.00, probability: 76.97%, Reference: ' Perhaps'\n",
|
903 |
-
"'!':\tRank 0.00, probability: 42.45%, Reference: '!'\n",
|
904 |
-
"' I':\tRank 0.00, probability: 11.67%, Reference: 'N/A'\n",
|
905 |
-
"''ll':\tRank 0.00, probability: 10.19%, Reference: 'N/A'\n",
|
906 |
-
"' have':\tRank 0.00, probability: 53.08%, Reference: 'N/A'\n",
|
907 |
-
"' to':\tRank 0.00, probability: 77.45%, Reference: 'N/A'\n",
|
908 |
-
"' test':\tRank 0.00, probability: 8.06%, Reference: 'N/A'\n",
|
909 |
-
"' this':\tRank 0.00, probability: 40.56%, Reference: 'N/A'\n",
|
910 |
-
"' out':\tRank 0.00, probability: 51.98%, Reference: 'N/A'\n",
|
911 |
-
"'.':\tRank 0.00, probability: 26.31%, Reference: 'N/A'\n",
|
912 |
-
"' Maybe':\tRank 0.00, probability: 27.11%, Reference: 'N/A'\n",
|
913 |
-
"'.':\tRank 0.00, probability: 22.16%, Reference: 'N/A'\n",
|
914 |
-
"' I':\tRank 0.00, probability: 20.00%, Reference: 'N/A'\n",
|
915 |
-
"''ll':\tRank 0.00, probability: 21.35%, Reference: 'N/A'\n",
|
916 |
-
"' have':\tRank 0.00, probability: 59.45%, Reference: 'N/A'\n",
|
917 |
-
"' to':\tRank 0.00, probability: 74.45%, Reference: 'N/A'\n",
|
918 |
-
"' test':\tRank 0.00, probability: 17.91%, Reference: 'N/A'\n",
|
919 |
-
"' this':\tRank 0.00, probability: 70.72%, Reference: 'N/A'\n",
|
920 |
-
"' out':\tRank 0.00, probability: 87.33%, Reference: 'N/A'\n",
|
921 |
-
"' some':\tRank 0.00, probability: 28.23%, Reference: 'N/A'\n",
|
922 |
-
"' more':\tRank 0.00, probability: 26.20%, Reference: 'N/A'\n",
|
923 |
-
"'.':\tRank 0.00, probability: 22.73%, Reference: 'N/A'\n",
|
924 |
-
"' Maybe':\tRank 0.00, probability: 36.98%, Reference: 'N/A'\n",
|
925 |
-
"'.':\tRank 0.00, probability: 54.91%, Reference: 'N/A'\n",
|
926 |
-
"' Maybe':\tRank 0.00, probability: 64.39%, Reference: 'N/A'\n",
|
927 |
-
"'.':\tRank 0.00, probability: 65.85%, Reference: 'N/A'\n",
|
928 |
-
"' Maybe':\tRank 0.00, probability: 65.80%, Reference: 'N/A'\n",
|
929 |
-
"'.':\tRank 0.00, probability: 51.26%, Reference: 'N/A'\n",
|
930 |
-
"' Maybe':\tRank 0.00, probability: 60.87%, Reference: 'N/A'\n",
|
931 |
-
"'.':\tRank 0.00, probability: 35.56%, Reference: 'N/A'\n",
|
932 |
-
"' Maybe':\tRank 0.00, probability: 48.26%, Reference: 'N/A'\n",
|
933 |
-
"'.':\tRank 0.00, probability: 28.23%, Reference: 'N/A'\n",
|
934 |
-
"' Maybe':\tRank 0.00, probability: 43.46%, Reference: 'N/A'\n",
|
935 |
-
"'.':\tRank 0.00, probability: 23.53%, Reference: 'N/A'\n",
|
936 |
-
"' Maybe':\tRank 0.00, probability: 42.47%, Reference: 'N/A'\n",
|
937 |
-
"'.':\tRank 0.00, probability: 21.56%, Reference: 'N/A'\n",
|
938 |
-
"' Maybe':\tRank 0.00, probability: 44.99%, Reference: 'N/A'\n",
|
939 |
-
"'.':\tRank 0.00, probability: 22.61%, Reference: 'N/A'\n",
|
940 |
-
"' Maybe':\tRank 0.00, probability: 48.24%, Reference: 'N/A'\n",
|
941 |
-
"'.':\tRank 0.00, probability: 19.47%, Reference: 'N/A'\n",
|
942 |
-
"' Maybe':\tRank 0.00, probability: 51.23%, Reference: 'N/A'\n",
|
943 |
-
"'.':\tRank 0.00, probability: 19.13%, Reference: 'N/A'\n",
|
944 |
-
"' Maybe':\tRank 0.00, probability: 54.27%, Reference: 'N/A'\n",
|
945 |
-
"'.':\tRank 0.00, probability: 20.28%, Reference: 'N/A'\n",
|
946 |
-
"' Maybe':\tRank 0.00, probability: 56.09%, Reference: 'N/A'\n",
|
947 |
-
"'.':\tRank 0.00, probability: 23.08%, Reference: 'N/A'\n",
|
948 |
-
"' Maybe':\tRank 0.00, probability: 60.23%, Reference: 'N/A'\n",
|
949 |
-
"'.':\tRank 0.00, probability: 25.86%, Reference: 'N/A'\n",
|
950 |
-
"' Maybe':\tRank 0.00, probability: 63.49%, Reference: 'N/A'\n",
|
951 |
-
"'.':\tRank 0.00, probability: 28.66%, Reference: 'N/A'\n",
|
952 |
-
"' Maybe':\tRank 0.00, probability: 65.52%, Reference: 'N/A'\n",
|
953 |
-
"'.':\tRank 0.00, probability: 29.81%, Reference: 'N/A'\n",
|
954 |
-
"' Maybe':\tRank 0.00, probability: 67.58%, Reference: 'N/A'\n",
|
955 |
-
"'.':\tRank 0.00, probability: 32.28%, Reference: 'N/A'\n",
|
956 |
-
"' Maybe':\tRank 0.00, probability: 69.40%, Reference: 'N/A'\n",
|
957 |
-
"'.':\tRank 0.00, probability: 33.50%, Reference: 'N/A'\n",
|
958 |
-
"' Maybe':\tRank 0.00, probability: 70.15%, Reference: 'N/A'\n",
|
959 |
-
"'.':\tRank 0.00, probability: 34.01%, Reference: 'N/A'\n",
|
960 |
-
"' Maybe':\tRank 0.00, probability: 71.20%, Reference: 'N/A'\n",
|
961 |
-
"'.':\tRank 0.00, probability: 36.83%, Reference: 'N/A'\n",
|
962 |
-
"' Maybe':\tRank 0.00, probability: 71.29%, Reference: 'N/A'\n",
|
963 |
-
"'.':\tRank 0.00, probability: 38.54%, Reference: 'N/A'\n",
|
964 |
-
"' Maybe':\tRank 0.00, probability: 72.09%, Reference: 'N/A'\n",
|
965 |
-
"'.':\tRank 0.00, probability: 40.10%, Reference: 'N/A'\n",
|
966 |
-
"' Maybe':\tRank 0.00, probability: 71.66%, Reference: 'N/A'\n",
|
967 |
-
"'.':\tRank 0.00, probability: 41.78%, Reference: 'N/A'\n",
|
968 |
-
"' Maybe':\tRank 0.00, probability: 72.04%, Reference: 'N/A'\n",
|
969 |
-
"'.':\tRank 0.00, probability: 42.55%, Reference: 'N/A'\n",
|
970 |
-
"' Maybe':\tRank 0.00, probability: 71.18%, Reference: 'N/A'\n",
|
971 |
-
"'.':\tRank 0.00, probability: 44.33%, Reference: 'N/A'\n",
|
972 |
-
"' Maybe':\tRank 0.00, probability: 70.62%, Reference: 'N/A'\n",
|
973 |
-
"'.':\tRank 0.00, probability: 44.84%, Reference: 'N/A'\n",
|
974 |
-
"' Maybe':\tRank 0.00, probability: 70.70%, Reference: 'N/A'\n",
|
975 |
-
"'.':\tRank 0.00, probability: 44.89%, Reference: 'N/A'\n",
|
976 |
-
"' Maybe':\tRank 0.00, probability: 71.80%, Reference: 'N/A'\n",
|
977 |
-
"'.':\tRank 0.00, probability: 46.26%, Reference: 'N/A'\n",
|
978 |
-
"' Maybe':\tRank 0.00, probability: 71.88%, Reference: 'N/A'\n",
|
979 |
-
"'.':\tRank 0.00, probability: 46.11%, Reference: 'N/A'\n",
|
980 |
-
"' Maybe':\tRank 0.00, probability: 72.24%, Reference: 'N/A'\n",
|
981 |
-
"'.':\tRank 0.00, probability: 44.98%, Reference: 'N/A'\n",
|
982 |
-
"' Maybe':\tRank 0.00, probability: 73.18%, Reference: 'N/A'\n",
|
983 |
-
"'.':\tRank 0.00, probability: 45.08%, Reference: 'N/A'\n",
|
984 |
-
"' Maybe':\tRank 0.00, probability: 75.74%, Reference: 'N/A'\n",
|
985 |
-
"'.':\tRank 0.00, probability: 45.33%, Reference: 'N/A'\n",
|
986 |
-
"' Maybe':\tRank 0.00, probability: 76.68%, Reference: 'N/A'\n",
|
987 |
-
"'.':\tRank 0.00, probability: 46.90%, Reference: 'N/A'\n",
|
988 |
-
"' Maybe':\tRank 0.00, probability: 78.19%, Reference: 'N/A'\n",
|
989 |
-
"'.':\tRank 0.00, probability: 46.31%, Reference: 'N/A'\n",
|
990 |
-
"' Maybe':\tRank 0.00, probability: 78.19%, Reference: 'N/A'\n",
|
991 |
-
"'.':\tRank 0.00, probability: 48.19%, Reference: 'N/A'\n",
|
992 |
-
"' Maybe':\tRank 0.00, probability: 79.38%, Reference: 'N/A'\n",
|
993 |
-
"'.':\tRank 0.00, probability: 48.56%, Reference: 'N/A'\n",
|
994 |
-
"' Maybe':\tRank 0.00, probability: 80.29%, Reference: 'N/A'\n",
|
995 |
-
"'.':\tRank 0.00, probability: 49.49%, Reference: 'N/A'\n",
|
996 |
-
"' Maybe':\tRank 0.00, probability: 79.51%, Reference: 'N/A'\n",
|
997 |
-
"'.':\tRank 0.00, probability: 51.13%, Reference: 'N/A'\n",
|
998 |
-
"' Maybe':\tRank 0.00, probability: 81.70%, Reference: 'N/A'\n",
|
999 |
-
"'.':\tRank 0.00, probability: 52.29%, Reference: 'N/A'\n",
|
1000 |
-
"' Maybe':\tRank 0.00, probability: 80.87%, Reference: 'N/A'\n",
|
1001 |
-
"'.':\tRank 0.00, probability: 52.84%, Reference: 'N/A'\n",
|
1002 |
-
"' Maybe':\tRank 0.00, probability: 81.61%, Reference: 'N/A'\n",
|
1003 |
-
"'.':\tRank 0.00, probability: 54.15%, Reference: 'N/A'\n",
|
1004 |
-
"' Maybe':\tRank 0.00, probability: 81.67%, Reference: 'N/A'\n",
|
1005 |
-
"'.':\tRank 0.00, probability: 55.17%, Reference: 'N/A'\n",
|
1006 |
-
"' Maybe':\tRank 0.00, probability: 82.23%, Reference: 'N/A'\n",
|
1007 |
-
"'.':\tRank 0.00, probability: 55.47%, Reference: 'N/A'\n",
|
1008 |
-
"' Maybe':\tRank 0.00, probability: 82.99%, Reference: 'N/A'\n",
|
1009 |
-
"'.':\tRank 0.00, probability: 57.04%, Reference: 'N/A'\n",
|
1010 |
-
"' Maybe':\tRank 0.00, probability: 83.32%, Reference: 'N/A'\n",
|
1011 |
-
"'.':\tRank 0.00, probability: 58.06%, Reference: 'N/A'\n",
|
1012 |
-
"' Maybe':\tRank 0.00, probability: 84.47%, Reference: 'N/A'\n",
|
1013 |
-
"'.':\tRank 0.00, probability: 59.85%, Reference: 'N/A'\n",
|
1014 |
-
"' Maybe':\tRank 0.00, probability: 83.82%, Reference: 'N/A'\n",
|
1015 |
-
"'.':\tRank 0.00, probability: 58.62%, Reference: 'N/A'\n",
|
1016 |
-
"' Maybe':\tRank 0.00, probability: 85.14%, Reference: 'N/A'\n",
|
1017 |
-
"'.':\tRank 0.00, probability: 60.26%, Reference: 'N/A'\n",
|
1018 |
-
"' Maybe':\tRank 0.00, probability: 85.67%, Reference: 'N/A'\n",
|
1019 |
-
"'.':\tRank 0.00, probability: 62.56%, Reference: 'N/A'\n",
|
1020 |
-
"' Maybe':\tRank 0.00, probability: 86.05%, Reference: 'N/A'\n",
|
1021 |
-
"'.':\tRank 0.00, probability: 64.28%, Reference: 'N/A'\n",
|
1022 |
-
"' Maybe':\tRank 0.00, probability: 86.93%, Reference: 'N/A'\n",
|
1023 |
-
"'.':\tRank 0.00, probability: 67.84%, Reference: 'N/A'\n",
|
1024 |
-
"' Maybe':\tRank 0.00, probability: 87.52%, Reference: 'N/A'\n",
|
1025 |
-
"'.':\tRank 0.00, probability: 73.01%, Reference: 'N/A'\n",
|
1026 |
-
"' Maybe':\tRank 0.00, probability: 88.17%, Reference: 'N/A'\n"
|
1027 |
-
]
|
1028 |
-
}
|
1029 |
-
],
|
1030 |
"source": [
|
1031 |
"import os\n",
|
1032 |
"import transformers\n",
|
@@ -1042,24 +186,24 @@
|
|
1042 |
" \"openai-community/gpt2\"\n",
|
1043 |
" )\n",
|
1044 |
"\n",
|
1045 |
-
"# Load the best
|
1046 |
-
"
|
1047 |
"\n",
|
1048 |
"# Do a greedy decoding manually\n",
|
1049 |
-
"# Pass the optimized
|
1050 |
"# We can't use .generate() since we need to pass the inputs_embeds\n",
|
1051 |
"all_logits = []\n",
|
1052 |
"generated_tokens = []\n",
|
1053 |
"with torch.no_grad():\n",
|
1054 |
" for i in tqdm.tqdm(range(150)):\n",
|
1055 |
" # Generate the logits for the next token using what we've generated so far\n",
|
1056 |
-
" # If there are no generated tokens yet, just take the
|
1057 |
" if len(generated_tokens) == 0:\n",
|
1058 |
-
" expanded_prompt =
|
1059 |
" else:\n",
|
1060 |
" expanded_prompt = torch.cat(\n",
|
1061 |
" [\n",
|
1062 |
-
"
|
1063 |
" model.transformer.wte(\n",
|
1064 |
" torch.tensor(generated_tokens).unsqueeze(0)\n",
|
1065 |
" ).detach(),\n",
|
@@ -1069,9 +213,9 @@
|
|
1069 |
"\n",
|
1070 |
" assert expanded_prompt.shape == (\n",
|
1071 |
" 1,\n",
|
1072 |
-
"
|
1073 |
" model.config.n_embd,\n",
|
1074 |
-
" ), f\"Got size {expanded_prompt.shape} instead of (1, {
|
1075 |
"\n",
|
1076 |
" # Generate the logits for the text\n",
|
1077 |
" logits: torch.Tensor = model(inputs_embeds=expanded_prompt).logits\n",
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
"id": "d1daab37",
|
7 |
"metadata": {},
|
8 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
"source": [
|
10 |
"import ipywidgets\n",
|
11 |
"\n",
|
|
|
20 |
},
|
21 |
{
|
22 |
"cell_type": "code",
|
23 |
+
"execution_count": null,
|
24 |
"id": "09b3c097",
|
25 |
"metadata": {},
|
26 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
"source": [
|
28 |
"import transformers\n",
|
29 |
"import torch\n",
|
|
|
50 |
"\n",
|
51 |
"embeddings = model.transformer.wte(tokens[\"input_ids\"]).detach()\n",
|
52 |
"\n",
|
53 |
+
"# We'll use a Simon Says prompt - a special token and its hidden states for the first token that we'll use to condition the model. We'll optimize the hidden states of this token to maximize the likelihood of the text that follows it.\n",
|
54 |
+
"# Generate a Simon Says prompt by creating random hidden states for the first token\n",
|
55 |
"# We'll optimize these hidden states to maximize the likelihood of the text that follows it\n",
|
56 |
"# past_key_values (Tuple[Tuple[torch.Tensor]] of length config.n_layers) — Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see past_key_values output below). Can be used to speed up sequential decoding. The input_ids which have their past given to this model should not be passed as input_ids as they have already been computed.\n",
|
57 |
"# Shape: past_key_values (Tuple[Tuple[torch.Tensor]] of length config.n_layers)\n",
|
58 |
"# with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head))\n",
|
59 |
"# The precision of the hidden states should be fp16\n",
|
60 |
+
"simon_says_prompt = (\n",
|
61 |
" # One tensor of shape (1, 1, embed_size_per_head)\n",
|
62 |
" torch.randn(\n",
|
63 |
" 1,\n",
|
64 |
" 4,\n",
|
65 |
" model.config.n_embd,\n",
|
66 |
" requires_grad=True,\n",
|
|
|
67 |
" )\n",
|
|
|
68 |
")\n",
|
69 |
"\n",
|
70 |
+
"# Copy the Simon Says prompts since we'll optimize them and they'll change\n",
|
71 |
+
"original_simon_says_prompt = tuple(key.detach().clone() for key in simon_says_prompt)\n",
|
72 |
"\n",
|
73 |
"# Define the optimizer\n",
|
74 |
+
"optimizer = torch.optim.Adam([simon_says_prompt], lr=1e-2)\n",
|
75 |
"\n",
|
76 |
"# Define the loss function\n",
|
77 |
"loss_fn = torch.nn.CrossEntropyLoss()\n",
|
|
|
85 |
"# Freeze the model parameters\n",
|
86 |
"model.eval()\n",
|
87 |
"\n",
|
88 |
+
"# Check that the Simon Says prompt is optimizable (requires_grad=True)\n",
|
89 |
+
"for key in simon_says_prompt:\n",
|
90 |
" assert key.requires_grad\n",
|
91 |
"\n",
|
92 |
"# Disable gradient computation for the model\n",
|
|
|
94 |
" param.requires_grad = False\n",
|
95 |
"\n",
|
96 |
"best_loss = float(\"inf\")\n",
|
97 |
+
"best_simon_says_prompt = None\n",
|
98 |
"\n",
|
99 |
+
"# Optimize the Simon Says prompt\n",
|
100 |
"for step in tqdm.tqdm(range(n_steps)):\n",
|
101 |
" # Zero the gradients\n",
|
102 |
" optimizer.zero_grad()\n",
|
103 |
"\n",
|
104 |
+
" # Add the optimizable Simon Says prompt to the embeddings\n",
|
105 |
+
" expanded_prompt = torch.cat([simon_says_prompt, embeddings], dim=1)\n",
|
106 |
"\n",
|
107 |
" # Generate the logits for the text\n",
|
108 |
" logits: torch.Tensor = model(inputs_embeds=expanded_prompt).logits\n",
|
109 |
"\n",
|
110 |
+
" probs = torch.softmax(logits[:, simon_says_prompt.size(-2) - 1 :-1], dim=-1)\n",
|
111 |
"\n",
|
112 |
" # Compute the ranks of the input IDs, i.e. how many tokens would have been more likely than the correct one (the label, the input IDs)\n",
|
113 |
" \n",
|
|
|
116 |
"\n",
|
117 |
" # Compute the loss\n",
|
118 |
" loss = loss_fn(\n",
|
119 |
+
" logits[:, simon_says_prompt.size(-2) - 1 :-1].reshape(-1, logits.size(-1)),\n",
|
120 |
" tokens[\"input_ids\"].reshape(-1),\n",
|
121 |
" )\n",
|
122 |
"\n",
|
123 |
" # Backpropagate the gradients\n",
|
124 |
" loss.backward()\n",
|
125 |
"\n",
|
126 |
+
" # Optimize the Simon Says prompt\n",
|
127 |
" optimizer.step()\n",
|
128 |
"\n",
|
129 |
" if step % 10 == 0:\n",
|
130 |
+
" # Get the L2 norm of the difference between the original and optimized Simon Says prompts\n",
|
131 |
" l2_norm = sum(\n",
|
132 |
" torch.norm(optimized - original, p=2)\n",
|
133 |
+
" for optimized, original in zip(simon_says_prompt, original_simon_says_prompt)\n",
|
134 |
" )\n",
|
135 |
" print(\n",
|
136 |
" f\"Step {step}, Loss: {loss.item()}, L2 norm: {l2_norm.item()}, avg rank: {ranks.float().mean().item()}\"\n",
|
|
|
139 |
" # Early stopping with patience\n",
|
140 |
" if loss.item() < best_loss and loss.item() > epsilon:\n",
|
141 |
" best_loss = loss.item()\n",
|
142 |
+
" best_simon_says_prompt = simon_says_prompt.detach().clone()\n",
|
143 |
" patience_counter = 0\n",
|
144 |
" else:\n",
|
145 |
" patience_counter += 1\n",
|
|
|
156 |
},
|
157 |
{
|
158 |
"cell_type": "code",
|
159 |
+
"execution_count": null,
|
160 |
"id": "cc9a6a2f",
|
161 |
"metadata": {},
|
162 |
"outputs": [],
|
163 |
"source": [
|
164 |
+
"# Save the best Simon Says prompt\n",
|
165 |
+
"torch.save(best_simon_says_prompt, \"best_simon_says_prompt.pt\")"
|
166 |
]
|
167 |
},
|
168 |
{
|
169 |
"cell_type": "code",
|
170 |
+
"execution_count": null,
|
171 |
"id": "3186747d",
|
172 |
"metadata": {},
|
173 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
"source": [
|
175 |
"import os\n",
|
176 |
"import transformers\n",
|
|
|
186 |
" \"openai-community/gpt2\"\n",
|
187 |
" )\n",
|
188 |
"\n",
|
189 |
+
"# Load the best Simon Says prompt from file\n",
|
190 |
+
"best_simon_says_prompt = torch.load(\"best_simon_says_prompt.pt\")\n",
|
191 |
"\n",
|
192 |
"# Do a greedy decoding manually\n",
|
193 |
+
"# Pass the optimized Simon Says prompt to the model\n",
|
194 |
"# We can't use .generate() since we need to pass the inputs_embeds\n",
|
195 |
"all_logits = []\n",
|
196 |
"generated_tokens = []\n",
|
197 |
"with torch.no_grad():\n",
|
198 |
" for i in tqdm.tqdm(range(150)):\n",
|
199 |
" # Generate the logits for the next token using what we've generated so far\n",
|
200 |
+
" # If there are no generated tokens yet, just take the Simon Says prompt\n",
|
201 |
" if len(generated_tokens) == 0:\n",
|
202 |
+
" expanded_prompt = best_simon_says_prompt\n",
|
203 |
" else:\n",
|
204 |
" expanded_prompt = torch.cat(\n",
|
205 |
" [\n",
|
206 |
+
" simon_says_prompt,\n",
|
207 |
" model.transformer.wte(\n",
|
208 |
" torch.tensor(generated_tokens).unsqueeze(0)\n",
|
209 |
" ).detach(),\n",
|
|
|
213 |
"\n",
|
214 |
" assert expanded_prompt.shape == (\n",
|
215 |
" 1,\n",
|
216 |
+
" best_simon_says_prompt.size(-2) + len(generated_tokens),\n",
|
217 |
" model.config.n_embd,\n",
|
218 |
+
" ), f\"Got size {expanded_prompt.shape} instead of (1, {best_simon_says_prompt.size(-2) + len(generated_tokens)}, {model.config.n_embd})\"\n",
|
219 |
"\n",
|
220 |
" # Generate the logits for the text\n",
|
221 |
" logits: torch.Tensor = model(inputs_embeds=expanded_prompt).logits\n",
|
requirements.txt
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
-
torch
|
2 |
-
transformers
|
|
|
3 |
tqdm
|
4 |
-
ipywidgets
|
|
|
|
1 |
+
torch==2.6.0
|
2 |
+
transformers==4.51.2
|
3 |
+
gradio==5.26.0
|
4 |
tqdm
|
5 |
+
ipywidgets
|
6 |
+
matplotlib
|