final changes
Browse files- app.py +1 -1
- tests/models/test_musicgen.py +2 -2
app.py
CHANGED
|
@@ -25,7 +25,7 @@ from audiocraft.models import MusicGen
|
|
| 25 |
|
| 26 |
MODEL = None # Last used model
|
| 27 |
IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
|
| 28 |
-
MAX_BATCH_SIZE =
|
| 29 |
BATCHED_DURATION = 15
|
| 30 |
INTERRUPTING = False
|
| 31 |
# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
|
|
|
|
| 25 |
|
| 26 |
MODEL = None # Last used model
|
| 27 |
IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
|
| 28 |
+
MAX_BATCH_SIZE = 12
|
| 29 |
BATCHED_DURATION = 15
|
| 30 |
INTERRUPTING = False
|
| 31 |
# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
|
tests/models/test_musicgen.py
CHANGED
|
@@ -13,7 +13,7 @@ from audiocraft.models import MusicGen
|
|
| 13 |
class TestSEANetModel:
|
| 14 |
def get_musicgen(self):
|
| 15 |
mg = MusicGen.get_pretrained(name='debug', device='cpu')
|
| 16 |
-
mg.set_generation_params(duration=2.0,
|
| 17 |
return mg
|
| 18 |
|
| 19 |
def test_base(self):
|
|
@@ -52,7 +52,7 @@ class TestSEANetModel:
|
|
| 52 |
def test_generate_long(self):
|
| 53 |
mg = self.get_musicgen()
|
| 54 |
mg.max_duration = 3.
|
| 55 |
-
mg.set_generation_params(duration=4.,
|
| 56 |
wav = mg.generate(
|
| 57 |
['youpi', 'lapin dort'])
|
| 58 |
assert list(wav.shape) == [2, 1, 32000 * 4]
|
|
|
|
| 13 |
class TestSEANetModel:
|
| 14 |
def get_musicgen(self):
|
| 15 |
mg = MusicGen.get_pretrained(name='debug', device='cpu')
|
| 16 |
+
mg.set_generation_params(duration=2.0, extend_stride=2.)
|
| 17 |
return mg
|
| 18 |
|
| 19 |
def test_base(self):
|
|
|
|
| 52 |
def test_generate_long(self):
|
| 53 |
mg = self.get_musicgen()
|
| 54 |
mg.max_duration = 3.
|
| 55 |
+
mg.set_generation_params(duration=4., extend_stride=2.)
|
| 56 |
wav = mg.generate(
|
| 57 |
['youpi', 'lapin dort'])
|
| 58 |
assert list(wav.shape) == [2, 1, 32000 * 4]
|