divyareddy commited on
Commit
22fee4d
Β·
verified Β·
1 Parent(s): 1365b6e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +236 -0
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import gradio as gr
4
+ import spaces
5
+ from queue import Queue
6
+ from threading import Thread
7
+ from typing import Optional
8
+ from transformers import MusicgenForConditionalGeneration, MusicgenProcessor, set_seed
9
+ from transformers.generation.streamers import BaseStreamer
10
+
11
+ model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
12
+ processor = MusicgenProcessor.from_pretrained("facebook/musicgen-small")
13
+
14
+ title = "9🌍MusicHub - Text to Music Stream Generator"
15
+ description = """ Facebook MusicGen-Small Model - Generate and stream music with model https://huggingface.co/facebook/musicgen-small """
16
+ article = """
17
+ ## How It Works:
18
+ MusicGen is an auto-regressive transformer-based model, meaning generates audio codes (tokens) in a causal fashion.
19
+ At each decoding step, the model generates a new set of audio codes, conditional on the text input and all previous audio codes. From the
20
+ frame rate of the [EnCodec model](https://huggingface.co/facebook/encodec_32khz) used to decode the generated codes to audio waveform.
21
+ """
22
+
23
+
24
+ class MusicgenStreamer(BaseStreamer):
25
+ def __init__(
26
+ self,
27
+ model: MusicgenForConditionalGeneration,
28
+ device: Optional[str] = None,
29
+ play_steps: Optional[int] = 10,
30
+ stride: Optional[int] = None,
31
+ timeout: Optional[float] = None,
32
+ ):
33
+ """
34
+ Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
35
+ useful for applications that benefit from acessing the generated audio in a non-blocking way (e.g. in an interactive
36
+ Gradio demo).
37
+ Parameters:
38
+ model (`MusicgenForConditionalGeneration`):
39
+ The MusicGen model used to generate the audio waveform.
40
+ device (`str`, *optional*):
41
+ The torch device on which to run the computation. If `None`, will default to the device of the model.
42
+ play_steps (`int`, *optional*, defaults to 10):
43
+ The number of generation steps with which to return the generated audio array. Using fewer steps will
44
+ mean the first chunk is ready faster, but will require more codec decoding steps overall. This value
45
+ should be tuned to your device and latency requirements.
46
+ stride (`int`, *optional*):
47
+ The window (stride) between adjacent audio samples. Using a stride between adjacent audio samples reduces
48
+ the hard boundary between them, giving smoother playback. If `None`, will default to a value equivalent to
49
+ play_steps // 6 in the audio space.
50
+ timeout (`int`, *optional*):
51
+ The timeout for the audio queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
52
+ in `.generate()`, when it is called in a separate thread.
53
+ """
54
+ self.decoder = model.decoder
55
+ self.audio_encoder = model.audio_encoder
56
+ self.generation_config = model.generation_config
57
+ self.device = device if device is not None else model.device
58
+
59
+ # variables used in the streaming process
60
+ self.play_steps = play_steps
61
+ if stride is not None:
62
+ self.stride = stride
63
+ else:
64
+ hop_length = np.prod(self.audio_encoder.config.upsampling_ratios)
65
+ self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
66
+ self.token_cache = None
67
+ self.to_yield = 0
68
+
69
+ # varibles used in the thread process
70
+ self.audio_queue = Queue()
71
+ self.stop_signal = None
72
+ self.timeout = timeout
73
+
74
+ def apply_delay_pattern_mask(self, input_ids):
75
+ # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
76
+ _, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
77
+ input_ids[:, :1],
78
+ pad_token_id=self.generation_config.decoder_start_token_id,
79
+ max_length=input_ids.shape[-1],
80
+ )
81
+ # apply the pattern mask to the input ids
82
+ input_ids = self.decoder.apply_delay_pattern_mask(input_ids, decoder_delay_pattern_mask)
83
+
84
+ # revert the pattern delay mask by filtering the pad token id
85
+ input_ids = input_ids[input_ids != self.generation_config.pad_token_id].reshape(
86
+ 1, self.decoder.num_codebooks, -1
87
+ )
88
+
89
+ # append the frame dimension back to the audio codes
90
+ input_ids = input_ids[None, ...]
91
+
92
+ # send the input_ids to the correct device
93
+ input_ids = input_ids.to(self.audio_encoder.device)
94
+
95
+ output_values = self.audio_encoder.decode(
96
+ input_ids,
97
+ audio_scales=[None],
98
+ )
99
+ audio_values = output_values.audio_values[0, 0]
100
+ return audio_values.cpu().float().numpy()
101
+
102
+ def put(self, value):
103
+ batch_size = value.shape[0] // self.decoder.num_codebooks
104
+ if batch_size > 1:
105
+ raise ValueError("MusicgenStreamer only supports batch size 1")
106
+
107
+ if self.token_cache is None:
108
+ self.token_cache = value
109
+ else:
110
+ self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)
111
+
112
+ if self.token_cache.shape[-1] % self.play_steps == 0:
113
+ audio_values = self.apply_delay_pattern_mask(self.token_cache)
114
+ self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
115
+ self.to_yield += len(audio_values) - self.to_yield - self.stride
116
+
117
+ def end(self):
118
+ """Flushes any remaining cache and appends the stop symbol."""
119
+ if self.token_cache is not None:
120
+ audio_values = self.apply_delay_pattern_mask(self.token_cache)
121
+ else:
122
+ audio_values = np.zeros(self.to_yield)
123
+
124
+ self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)
125
+
126
+ def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
127
+ """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
128
+ self.audio_queue.put(audio, timeout=self.timeout)
129
+ if stream_end:
130
+ self.audio_queue.put(self.stop_signal, timeout=self.timeout)
131
+
132
+ def __iter__(self):
133
+ return self
134
+
135
+ def __next__(self):
136
+ value = self.audio_queue.get(timeout=self.timeout)
137
+ if not isinstance(value, np.ndarray) and value == self.stop_signal:
138
+ raise StopIteration()
139
+ else:
140
+ return value
141
+
142
+
143
+ sampling_rate = model.audio_encoder.config.sampling_rate
144
+ frame_rate = model.audio_encoder.config.frame_rate
145
+
146
+ target_dtype = np.int16
147
+ max_range = np.iinfo(target_dtype).max
148
+
149
+
150
+ @spaces.GPU
151
+ def generate_audio(text_prompt, audio_length_in_s=10.0, play_steps_in_s=2.0, seed=0):
152
+ max_new_tokens = int(frame_rate * audio_length_in_s)
153
+ play_steps = int(frame_rate * play_steps_in_s)
154
+
155
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
156
+ if device != model.device:
157
+ model.to(device)
158
+ if device == "cuda:0":
159
+ model.half()
160
+
161
+ inputs = processor(
162
+ text=text_prompt,
163
+ padding=True,
164
+ return_tensors="pt",
165
+ )
166
+
167
+ streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)
168
+
169
+ generation_kwargs = dict(
170
+ **inputs.to(device),
171
+ streamer=streamer,
172
+ max_new_tokens=max_new_tokens,
173
+ )
174
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
175
+ thread.start()
176
+
177
+ set_seed(seed)
178
+ for new_audio in streamer:
179
+ print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
180
+ new_audio = (new_audio * max_range).astype(np.int16)
181
+ yield (sampling_rate, new_audio)
182
+
183
+
184
+ demo = gr.Interface(
185
+ fn=generate_audio,
186
+ inputs=[
187
+ gr.Text(label="Prompt", value="80s pop track with synth and instrumentals"),
188
+ gr.Slider(10, 30, value=15, step=5, label="Audio length in seconds"),
189
+ gr.Slider(0.5, 2.5, value=0.5, step=0.5, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps"),
190
+ gr.Slider(0, 10, value=5, step=1, label="Seed for random generations"),
191
+ ],
192
+ outputs=[
193
+ gr.Audio(label="Generated Music", streaming=True, autoplay=True)
194
+ ],
195
+ #examples = [
196
+ # ["Country acoustic guitar fast line dance singer like Kenny Chesney and Garth brooks and Luke Combs and Chris Stapleton. bpm: 100", 30, 0.5, 5],
197
+ # ["Electronic Dance track with pulsating bass and high energy synths. bpm: 126", 30, 0.5, 5],
198
+ # ["Rap Beats with deep bass and snappy snares. bpm: 80", 30, 0.5, 5],
199
+ # ["Lo-Fi track with smooth beats and chill vibes. bpm: 100", 30, 0.5, 5],
200
+ # ["Global Groove track with international instruments and dance rhythms. bpm: 128", 30, 0.5, 5],
201
+ # ["Relaxing Meditation music with ambient pads and soothing melodies. bpm: 80", 30, 0.5, 5],
202
+ # ["Rave Dance track with hard-hitting beats and euphoric synths. bpm: 128", 30, 0.5, 5]
203
+ #],
204
+
205
+ examples = [
206
+ ["🌟 Dance/EDM track with vibrant synths and a driving beat. bpm: 128", 30, 0.5, 5],
207
+ ["πŸ’Ό Corporate theme with an upbeat tempo and motivational melodies. bpm: 120", 30, 0.5, 5],
208
+ ["🎸 Rock anthem with powerful guitar riffs and energetic drums. bpm: 140", 30, 0.5, 5],
209
+ ["🌊 Chill Out track with soothing ambient sounds and relaxed tempo. bpm: 90", 30, 0.5, 5],
210
+ ["🎀 Hip Hop beat with hard-hitting bass and catchy rhythms. bpm: 95", 30, 0.5, 5],
211
+ ["🎻 Orchestral piece with dramatic strings and grand composition. bpm: 70", 30, 0.5, 5],
212
+ ["πŸ•Ί Funk groove with groovy basslines and rhythmic guitars. bpm: 110", 30, 0.5, 5],
213
+ ["πŸ•ΉοΈ Video Game music with retro synths and catchy chiptune melodies. bpm: 130", 30, 0.5, 5],
214
+ ["🌾 Folk song with acoustic guitar and harmonious vocals. bpm: 85", 30, 0.5, 5],
215
+ ["πŸŒ™ Ambient soundscape with ethereal pads and calming tones. bpm: 60", 30, 0.5, 5],
216
+ ["🎷 Jazz tune with smooth saxophone and swinging rhythms. bpm: 120", 30, 0.5, 5],
217
+ ["πŸ‘Ά Kids music with cheerful melodies and playful instruments. bpm: 100", 30, 0.5, 5],
218
+ ["🌟 Pop hit with catchy hooks and upbeat rhythms. bpm: 115", 30, 0.5, 5],
219
+ ["🎬 Production music with versatile sounds for various media. bpm: 110", 30, 0.5, 5],
220
+ ["πŸ”Š Electronic piece with experimental sounds and unique textures. bpm: 125", 30, 0.5, 5],
221
+ ["🌴 Reggae rhythm with laid-back vibes and offbeat guitar chords. bpm: 75", 30, 0.5, 5],
222
+ ["πŸ’ƒ Dance track with infectious beats and lively energy. bpm: 130", 30, 0.5, 5],
223
+ ["🎀 R&B tune with smooth vocals and soulful grooves. bpm: 90", 30, 0.5, 5],
224
+ ["πŸŽ‰ Latin song with rhythmic percussion and fiery melodies. bpm: 105", 30, 0.5, 5],
225
+ ["🀠 Country track with twangy guitars and heartfelt storytelling. bpm: 85", 30, 0.5, 5],
226
+ ["🎢 Indian music with traditional instruments and intricate rhythms. bpm: 95", 30, 0.5, 5]
227
+ ],
228
+
229
+
230
+ title=title,
231
+ description=description,
232
+ article=article,
233
+ cache_examples=False,
234
+ )
235
+
236
+ demo.queue().launch()