Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -303,18 +303,31 @@ def generate_music(prime, num_gen_tokens, num_gen_batches, model_temperature, mo
|
|
| 303 |
|
| 304 |
print("Generating...")
|
| 305 |
inp = torch.LongTensor([inputs] * num_gen_batches).cuda()
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
print("Done!")
|
| 319 |
print_sep()
|
| 320 |
return out.tolist()
|
|
@@ -577,7 +590,7 @@ with gr.Blocks() as demo:
|
|
| 577 |
num_prime_tokens = gr.Slider(16, 6656, value=6656, step=1, label="Number of prime tokens")
|
| 578 |
num_gen_tokens = gr.Slider(16, 1024, value=512, step=1, label="Number of tokens to generate")
|
| 579 |
model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
|
| 580 |
-
model_top_p = gr.Slider(0.1, 0
|
| 581 |
add_drums = gr.Checkbox(value=False, label="Add drums")
|
| 582 |
add_outro = gr.Checkbox(value=False, label="Add an outro")
|
| 583 |
generate_btn = gr.Button("Generate", variant="primary")
|
|
|
|
| 303 |
|
| 304 |
print("Generating...")
|
| 305 |
inp = torch.LongTensor([inputs] * num_gen_batches).cuda()
|
| 306 |
+
|
| 307 |
+
if model_top_p < 1:
|
| 308 |
+
with ctx:
|
| 309 |
+
out = model.generate(
|
| 310 |
+
inp,
|
| 311 |
+
num_gen_tokens,
|
| 312 |
+
filter_logits_fn=top_p,
|
| 313 |
+
filter_kwargs={'thres': model_top_p},
|
| 314 |
+
temperature=model_temperature,
|
| 315 |
+
eos_token=18818,
|
| 316 |
+
return_prime=False,
|
| 317 |
+
verbose=False
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
else:
|
| 321 |
+
with ctx:
|
| 322 |
+
out = model.generate(
|
| 323 |
+
inp,
|
| 324 |
+
num_gen_tokens,
|
| 325 |
+
temperature=model_temperature,
|
| 326 |
+
eos_token=18818,
|
| 327 |
+
return_prime=False,
|
| 328 |
+
verbose=False
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
print("Done!")
|
| 332 |
print_sep()
|
| 333 |
return out.tolist()
|
|
|
|
| 590 |
num_prime_tokens = gr.Slider(16, 6656, value=6656, step=1, label="Number of prime tokens")
|
| 591 |
num_gen_tokens = gr.Slider(16, 1024, value=512, step=1, label="Number of tokens to generate")
|
| 592 |
model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
|
| 593 |
+
model_top_p = gr.Slider(0.1, 1.0, value=0.96, step=0.01, label="Model sampling top p value")
|
| 594 |
add_drums = gr.Checkbox(value=False, label="Add drums")
|
| 595 |
add_outro = gr.Checkbox(value=False, label="Add an outro")
|
| 596 |
generate_btn = gr.Button("Generate", variant="primary")
|