Spaces:
Build error
Build error
no progress on batch
Browse files
app.py
CHANGED
|
@@ -49,6 +49,7 @@ def interrupt():
|
|
| 49 |
global INTERRUPTING
|
| 50 |
INTERRUPTING = True
|
| 51 |
|
|
|
|
| 52 |
def make_waveform(*args, **kwargs):
|
| 53 |
# Further remove some warnings.
|
| 54 |
be = time.time()
|
|
@@ -66,7 +67,7 @@ def load_model(version='melody'):
|
|
| 66 |
MODEL = MusicGen.get_pretrained(version)
|
| 67 |
|
| 68 |
|
| 69 |
-
def _do_predictions(texts, melodies, duration, **gen_kwargs):
|
| 70 |
MODEL.set_generation_params(duration=duration, **gen_kwargs)
|
| 71 |
print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
|
| 72 |
be = time.time()
|
|
@@ -89,10 +90,10 @@ def _do_predictions(texts, melodies, duration, **gen_kwargs):
|
|
| 89 |
descriptions=texts,
|
| 90 |
melody_wavs=processed_melodies,
|
| 91 |
melody_sample_rate=target_sr,
|
| 92 |
-
progress=
|
| 93 |
)
|
| 94 |
else:
|
| 95 |
-
outputs = MODEL.generate(texts, progress=
|
| 96 |
|
| 97 |
outputs = outputs.detach().cpu().float()
|
| 98 |
out_files = []
|
|
@@ -128,7 +129,7 @@ def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coe
|
|
| 128 |
MODEL.set_custom_progress_callback(_progress)
|
| 129 |
|
| 130 |
outs = _do_predictions(
|
| 131 |
-
[text], [melody], duration,
|
| 132 |
top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
|
| 133 |
return outs[0]
|
| 134 |
|
|
@@ -324,6 +325,8 @@ if __name__ == "__main__":
|
|
| 324 |
args = parser.parse_args()
|
| 325 |
|
| 326 |
launch_kwargs = {}
|
|
|
|
|
|
|
| 327 |
if args.username and args.password:
|
| 328 |
launch_kwargs['auth'] = (args.username, args.password)
|
| 329 |
if args.server_port:
|
|
|
|
| 49 |
global INTERRUPTING
|
| 50 |
INTERRUPTING = True
|
| 51 |
|
| 52 |
+
|
| 53 |
def make_waveform(*args, **kwargs):
|
| 54 |
# Further remove some warnings.
|
| 55 |
be = time.time()
|
|
|
|
| 67 |
MODEL = MusicGen.get_pretrained(version)
|
| 68 |
|
| 69 |
|
| 70 |
+
def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
|
| 71 |
MODEL.set_generation_params(duration=duration, **gen_kwargs)
|
| 72 |
print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
|
| 73 |
be = time.time()
|
|
|
|
| 90 |
descriptions=texts,
|
| 91 |
melody_wavs=processed_melodies,
|
| 92 |
melody_sample_rate=target_sr,
|
| 93 |
+
progress=progress,
|
| 94 |
)
|
| 95 |
else:
|
| 96 |
+
outputs = MODEL.generate(texts, progress=progress)
|
| 97 |
|
| 98 |
outputs = outputs.detach().cpu().float()
|
| 99 |
out_files = []
|
|
|
|
| 129 |
MODEL.set_custom_progress_callback(_progress)
|
| 130 |
|
| 131 |
outs = _do_predictions(
|
| 132 |
+
[text], [melody], duration, progress=True,
|
| 133 |
top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
|
| 134 |
return outs[0]
|
| 135 |
|
|
|
|
| 325 |
args = parser.parse_args()
|
| 326 |
|
| 327 |
launch_kwargs = {}
|
| 328 |
+
launch_kwargs['server_name'] = args.listen
|
| 329 |
+
|
| 330 |
if args.username and args.password:
|
| 331 |
launch_kwargs['auth'] = (args.username, args.password)
|
| 332 |
if args.server_port:
|