asigalov61 commited on
Commit
7da7c8e
·
verified ·
1 Parent(s): d2d98db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -13
app.py CHANGED
@@ -303,18 +303,31 @@ def generate_music(prime, num_gen_tokens, num_gen_batches, model_temperature, mo
303
 
304
  print("Generating...")
305
  inp = torch.LongTensor([inputs] * num_gen_batches).cuda()
306
- with ctx:
307
- out = model.generate(
308
- inp,
309
- num_gen_tokens,
310
- filter_logits_fn=top_p,
311
- filter_kwargs={'thres': model_top_p},
312
- temperature=model_temperature,
313
- eos_token=18818,
314
- return_prime=False,
315
- verbose=False
316
- )
317
-
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  print("Done!")
319
  print_sep()
320
  return out.tolist()
@@ -577,7 +590,7 @@ with gr.Blocks() as demo:
577
  num_prime_tokens = gr.Slider(16, 6656, value=6656, step=1, label="Number of prime tokens")
578
  num_gen_tokens = gr.Slider(16, 1024, value=512, step=1, label="Number of tokens to generate")
579
  model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
580
- model_top_p = gr.Slider(0.1, 0.99, value=0.96, step=0.01, label="Model sampling top p value")
581
  add_drums = gr.Checkbox(value=False, label="Add drums")
582
  add_outro = gr.Checkbox(value=False, label="Add an outro")
583
  generate_btn = gr.Button("Generate", variant="primary")
 
303
 
304
  print("Generating...")
305
  inp = torch.LongTensor([inputs] * num_gen_batches).cuda()
306
+
307
+ if model_top_p < 1:
308
+ with ctx:
309
+ out = model.generate(
310
+ inp,
311
+ num_gen_tokens,
312
+ filter_logits_fn=top_p,
313
+ filter_kwargs={'thres': model_top_p},
314
+ temperature=model_temperature,
315
+ eos_token=18818,
316
+ return_prime=False,
317
+ verbose=False
318
+ )
319
+
320
+ else:
321
+ with ctx:
322
+ out = model.generate(
323
+ inp,
324
+ num_gen_tokens,
325
+ temperature=model_temperature,
326
+ eos_token=18818,
327
+ return_prime=False,
328
+ verbose=False
329
+ )
330
+
331
  print("Done!")
332
  print_sep()
333
  return out.tolist()
 
590
  num_prime_tokens = gr.Slider(16, 6656, value=6656, step=1, label="Number of prime tokens")
591
  num_gen_tokens = gr.Slider(16, 1024, value=512, step=1, label="Number of tokens to generate")
592
  model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
593
+ model_top_p = gr.Slider(0.1, 1.0, value=0.96, step=0.01, label="Model sampling top p value")
594
  add_drums = gr.Checkbox(value=False, label="Add drums")
595
  add_outro = gr.Checkbox(value=False, label="Add an outro")
596
  generate_btn = gr.Button("Generate", variant="primary")