soiz committed on
Commit
accc372
·
verified ·
1 Parent(s): b6a49a3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from queue import Queue
2
+ from threading import Thread
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ from flask import Flask, request, jsonify, send_file
8
+ from transformers import MusicgenForConditionalGeneration, MusicgenProcessor, set_seed
9
+ from transformers.generation.streamers import BaseStreamer
10
+ import io
11
+ import soundfile as sf
12
+
13
# Load the model and processor once at import time so every request
# reuses the same weights instead of re-downloading/re-initializing.
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
processor = MusicgenProcessor.from_pretrained("facebook/musicgen-small")
16
+
17
class MusicgenStreamer(BaseStreamer):
    """Streamer that incrementally decodes MusicGen tokens to audio.

    `model.generate` pushes decoder token ids into `put()`; every
    `play_steps` tokens the accumulated cache is decoded through the audio
    encoder and the newly available waveform samples are placed on
    `audio_queue`. Iterating the streamer yields those numpy chunks until
    generation ends (signalled by `stop_signal`, i.e. ``None``).
    """

    def __init__(
        self,
        model: MusicgenForConditionalGeneration,
        device: Optional[str] = None,
        play_steps: Optional[int] = 10,
        stride: Optional[int] = None,
        timeout: Optional[float] = None,
    ):
        # Keep direct handles to the sub-modules used during streaming.
        self.decoder = model.decoder
        self.audio_encoder = model.audio_encoder
        self.generation_config = model.generation_config
        self.device = device if device is not None else model.device

        # Number of decoder steps between partial decodes.
        self.play_steps = play_steps
        if stride is not None:
            self.stride = stride
        else:
            # hop_length: audio samples produced per decoder token
            # (product of the codec's upsampling ratios).
            hop_length = np.prod(self.audio_encoder.config.upsampling_ratios)
            # NOTE(review): the `// 6` divisor looks like an empirically
            # chosen overlap margin to hide codec edge artifacts between
            # chunks — confirm against the upstream streaming demo.
            self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
        self.token_cache = None  # all decoder tokens seen so far
        self.to_yield = 0        # number of audio samples already emitted

        # Queue bridging the generation thread (producer) and the HTTP
        # handler iterating this streamer (consumer).
        self.audio_queue = Queue()
        self.stop_signal = None
        self.timeout = timeout

    def apply_delay_pattern_mask(self, input_ids):
        """Undo MusicGen's codebook delay pattern and decode to a waveform.

        Returns a 1-D float numpy array of audio samples for the whole
        token cache (mono: channel [0, 0] of the decoder output).
        """
        # Rebuild the delay-pattern mask for the current sequence length.
        _, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
            input_ids[:, :1],
            pad_token_id=self.generation_config.decoder_start_token_id,
            max_length=input_ids.shape[-1],
        )
        input_ids = self.decoder.apply_delay_pattern_mask(input_ids, decoder_delay_pattern_mask)

        # Strip pad tokens and regroup into (1, num_codebooks, seq_len).
        input_ids = input_ids[input_ids != self.generation_config.pad_token_id].reshape(
            1, self.decoder.num_codebooks, -1
        )

        # Add the chunk dimension expected by the audio encoder and move
        # the ids to wherever the codec lives.
        input_ids = input_ids[None, ...]
        input_ids = input_ids.to(self.audio_encoder.device)

        output_values = self.audio_encoder.decode(
            input_ids,
            audio_scales=[None],
        )
        audio_values = output_values.audio_values[0, 0]
        return audio_values.cpu().float().numpy()

    def put(self, value):
        """Receive one step of decoder tokens from `generate` (producer side)."""
        # `value` stacks all codebooks along the batch axis.
        batch_size = value.shape[0] // self.decoder.num_codebooks
        if batch_size > 1:
            raise ValueError("MusicgenStreamer only supports batch size 1")

        if self.token_cache is None:
            self.token_cache = value
        else:
            # Append the new step along the sequence axis.
            self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)

        # Every `play_steps` tokens, decode and emit the samples that became
        # available, holding back `stride` samples to smooth chunk seams.
        if self.token_cache.shape[-1] % self.play_steps == 0:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
            self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
            self.to_yield += len(audio_values) - self.to_yield - self.stride

    def end(self):
        """Flush the remaining audio and signal end-of-stream."""
        if self.token_cache is not None:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
        else:
            # Nothing was generated: emit silence of the already-promised length.
            audio_values = np.zeros(self.to_yield)

        self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)

    def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
        """Push a finished audio chunk (and optionally the stop signal) onto the queue."""
        self.audio_queue.put(audio, timeout=self.timeout)
        if stream_end:
            self.audio_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        # Blocks until the producer thread supplies the next chunk.
        value = self.audio_queue.get(timeout=self.timeout)
        if not isinstance(value, np.ndarray) and value == self.stop_signal:
            raise StopIteration()
        else:
            return value
103
+
104
+
105
# Codec properties: sampling_rate is the output waveform rate;
# frame_rate is the decoder token rate (used below to convert
# seconds -> token counts).
sampling_rate = model.audio_encoder.config.sampling_rate
frame_rate = model.audio_encoder.config.frame_rate

app = Flask(__name__)
109
+
110
@app.route('/generate_audio', methods=['POST'])
def generate_audio():
    """Generate music from a text prompt and return it as a WAV download.

    JSON body (all fields optional):
        text_prompt       -- prompt text (default: '80s pop track with synth and instrumentals')
        audio_length_in_s -- total audio length in seconds (default 10.0)
        play_steps_in_s   -- streaming chunk size in seconds (default 2.0)
        seed              -- RNG seed for reproducible generation (default 0)

    Returns: a ``audio/wav`` attachment named ``generated_music.wav``.
    """
    data = request.json
    text_prompt = data.get('text_prompt', '80s pop track with synth and instrumentals')
    audio_length_in_s = float(data.get('audio_length_in_s', 10.0))
    play_steps_in_s = float(data.get('play_steps_in_s', 2.0))
    seed = int(data.get('seed', 0))

    # frame_rate converts seconds of audio into decoder token counts.
    max_new_tokens = int(frame_rate * audio_length_in_s)
    play_steps = int(frame_rate * play_steps_in_s)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # BUGFIX: `model.device` is a torch.device, so the original
    # `device != model.device` (str vs torch.device) was always truthy and
    # the model was re-moved on every request. Compare the string forms so
    # the move/half-precision cast happens only once.
    if str(model.device) != device:
        model.to(device)
        if device == "cuda:0":
            model.half()

    inputs = processor(
        text=text_prompt,
        padding=True,
        return_tensors="pt",
    )

    streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)

    generation_kwargs = dict(
        **inputs.to(device),
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )

    # BUGFIX: seed BEFORE starting the generation thread. The original
    # called set_seed(seed) after thread.start(), racing the sampling loop
    # and making the seed unreliable.
    set_seed(seed)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Drain the streamer until the generation thread signals end-of-stream.
    generated_audio = []
    for new_audio in streamer:
        generated_audio.append(new_audio)

    # Concatenate the audio chunks into a single waveform.
    final_audio = np.concatenate(generated_audio)

    # Write a WAV into an in-memory buffer and send it as the response.
    buffer = io.BytesIO()
    sf.write(buffer, final_audio, sampling_rate, format="wav")
    buffer.seek(0)

    return send_file(buffer, mimetype="audio/wav", as_attachment=True, download_name="generated_music.wav")
158
# Start the Flask development server on all interfaces, port 8000.
# NOTE(review): Flask's built-in server is not production-grade; deploy
# behind a WSGI server (e.g. gunicorn) for real traffic.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8000)