Mohaddz committed on
Commit
580bcb7
·
verified ·
1 Parent(s): e360346

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +198 -105
app.py CHANGED
@@ -3,66 +3,102 @@ from snac import SNAC
3
  import torch
4
  import gradio as gr
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
- from huggingface_hub import snapshot_download
7
  from dotenv import load_dotenv
 
 
8
  load_dotenv()
9
 
 
 
 
 
 
 
 
 
10
  # Check if CUDA is available
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
12
 
13
  print("Loading SNAC model...")
14
- snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
15
- snac_model = snac_model.to(device)
16
-
17
- model_name = "Mohaddz/orpheus-3b-0.1-ft-ar"
18
-
19
- # Download only model config and safetensors
20
- snapshot_download(
21
- repo_id=model_name,
22
- allow_patterns=[
23
- "config.json",
24
- "*.safetensors",
25
- "model.safetensors.index.json",
26
- ],
27
- ignore_patterns=[
28
- "optimizer.pt",
29
- "pytorch_model.bin",
30
- "training_args.bin",
31
- "scheduler.pt",
32
- "tokenizer.json",
33
- "tokenizer_config.json",
34
- "special_tokens_map.json",
35
- "vocab.json",
36
- "merges.txt",
37
- "tokenizer.*"
38
- ]
39
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
42
- model.to(device)
43
- tokenizer = AutoTokenizer.from_pretrained(model_name)
44
- print(f"Orpheus model loaded to {device}")
45
 
46
- # Process text prompt
47
- def process_prompt(prompt, voice, tokenizer, device):
 
 
48
  prompt = f"{voice}: {prompt}"
49
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids
50
-
51
  start_token = torch.tensor([[128259]], dtype=torch.int64) # Start of human
52
  end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64) # End of text, End of human
53
-
54
  modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1) # SOH SOT Text EOT EOH
55
-
56
- # No padding needed for single input
57
  attention_mask = torch.ones_like(modified_input_ids)
58
-
59
  return modified_input_ids.to(device), attention_mask.to(device)
60
 
61
- # Parse output tokens to audio
62
  def parse_output(generated_ids):
63
  token_to_find = 128257
64
  token_to_remove = 128258
65
-
66
  token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
67
 
68
  if len(token_indices[1]) > 0:
@@ -81,19 +117,23 @@ def parse_output(generated_ids):
81
  row_length = row.size(0)
82
  new_length = (row_length // 7) * 7
83
  trimmed_row = row[:new_length]
84
- trimmed_row = [t - 128266 for t in trimmed_row]
85
  code_lists.append(trimmed_row)
86
-
87
- return code_lists[0] # Return just the first one for single sample
88
 
89
- # Redistribute codes for audio generation
90
- def redistribute_codes(code_list, snac_model):
91
- device = next(snac_model.parameters()).device # Get the device of SNAC model
92
-
 
 
 
 
 
93
  layer_1 = []
94
  layer_2 = []
95
  layer_3 = []
96
- for i in range((len(code_list)+1)//7):
 
97
  layer_1.append(code_list[7*i])
98
  layer_2.append(code_list[7*i+1]-4096)
99
  layer_3.append(code_list[7*i+2]-(2*4096))
@@ -101,137 +141,190 @@ def redistribute_codes(code_list, snac_model):
101
  layer_2.append(code_list[7*i+4]-(4*4096))
102
  layer_3.append(code_list[7*i+5]-(5*4096))
103
  layer_3.append(code_list[7*i+6]-(6*4096))
104
-
105
- # Move tensors to the same device as the SNAC model
 
 
 
106
  codes = [
107
- torch.tensor(layer_1, device=device).unsqueeze(0),
108
- torch.tensor(layer_2, device=device).unsqueeze(0),
109
- torch.tensor(layer_3, device=device).unsqueeze(0)
110
  ]
111
-
112
- audio_hat = snac_model.decode(codes)
113
- return audio_hat.detach().squeeze().cpu().numpy() # Always return CPU numpy array
114
 
115
- # Main generation function
 
 
 
 
 
116
  @spaces.GPU()
117
- def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
 
 
 
 
 
 
 
 
118
  if not text.strip():
 
119
  return None
120
-
121
  try:
122
  progress(0.1, "Processing text...")
123
- input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
124
-
125
  progress(0.3, "Generating speech tokens...")
126
  with torch.no_grad():
127
- generated_ids = model.generate(
 
128
  input_ids=input_ids,
129
  attention_mask=attention_mask,
130
  max_new_tokens=max_new_tokens,
131
  do_sample=True,
132
- temperature=temperature,
133
  top_p=top_p,
134
  repetition_penalty=repetition_penalty,
135
  num_return_sequences=1,
136
- eos_token_id=128258,
 
137
  )
138
-
139
  progress(0.6, "Processing speech tokens...")
140
  code_list = parse_output(generated_ids)
141
-
142
  progress(0.8, "Converting to audio...")
143
  audio_samples = redistribute_codes(code_list, snac_model)
144
-
 
 
 
 
145
  return (24000, audio_samples) # Return sample rate and audio
146
  except Exception as e:
147
  print(f"Error generating speech: {e}")
 
 
 
148
  return None
149
 
 
 
 
 
 
 
 
 
150
  # Examples for the UI
151
  examples = [
152
- ["Hey there my name is Tara, <chuckle> and I'm a speech generation model that can sound like a person.", "tara", 0.6, 0.95, 1.1, 1200],
153
- ["I've also been taught to understand and produce paralinguistic things like sighing, or chuckling, or yawning!", "dan", 0.7, 0.95, 1.1, 1200],
154
- ["I live in San Francisco, and have, uhm let's see, 3 billion 7 hundred ... well, lets just say a lot of parameters.", "emma", 0.6, 0.9, 1.2, 1200]
 
155
  ]
156
 
157
- # Available voices
158
- VOICES = ["tara", "dan", "josh", "emma"]
 
159
 
160
  # Create Gradio interface
161
  with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
162
  gr.Markdown("""
163
- # 🎵 [Orpheus Text-to-Speech](https://github.com/canopyai/Orpheus-TTS)
164
- Enter your text below and hear it converted to natural-sounding speech with the Orpheus TTS model.
165
-
166
- ## Tips for better prompts:
167
- - Add paralinguistic elements like `<chuckle>`, `<sigh>`, or `uhm` for more human-like speech.
168
- - Longer text prompts generally work better than very short phrases
169
- - Adjust the temperature slider for more varied (higher) or consistent (lower) speech patterns
170
- """)
 
 
 
 
 
 
 
 
 
171
  with gr.Row():
172
  with gr.Column(scale=3):
173
  text_input = gr.Textbox(
174
- label="Text to speak",
175
- placeholder="Enter your text here...",
176
- lines=5
 
177
  )
178
  voice = gr.Dropdown(
179
- choices=VOICES,
180
- value="tara",
181
- label="Voice"
182
  )
183
-
184
- with gr.Accordion("Advanced Settings", open=False):
185
  temperature = gr.Slider(
186
  minimum=0.1, maximum=1.5, value=0.6, step=0.05,
187
- label="Temperature",
188
  info="Higher values (0.7-1.0) create more expressive but less stable speech"
189
  )
190
  top_p = gr.Slider(
191
  minimum=0.1, maximum=1.0, value=0.95, step=0.05,
192
- label="Top P",
193
  info="Nucleus sampling threshold"
194
  )
195
  repetition_penalty = gr.Slider(
196
  minimum=1.0, maximum=2.0, value=1.1, step=0.05,
197
- label="Repetition Penalty",
198
  info="Higher values discourage repetitive patterns"
199
  )
200
  max_new_tokens = gr.Slider(
201
  minimum=100, maximum=2000, value=1200, step=100,
202
- label="Max Length",
203
  info="Maximum length of generated audio (in tokens)"
204
  )
205
-
206
  with gr.Row():
207
- submit_btn = gr.Button("Generate Speech", variant="primary")
208
- clear_btn = gr.Button("Clear")
209
-
210
  with gr.Column(scale=2):
211
- audio_output = gr.Audio(label="Generated Speech", type="numpy")
212
-
213
  # Set up examples
214
  gr.Examples(
215
  examples=examples,
216
  inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
217
  outputs=audio_output,
218
- fn=generate_speech,
219
- cache_examples=True,
 
 
 
 
 
 
 
 
220
  )
221
-
222
- # Set up event handlers
223
  submit_btn.click(
224
  fn=generate_speech,
225
  inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
226
  outputs=audio_output
227
  )
228
-
 
229
  clear_btn.click(
230
  fn=lambda: (None, None),
231
  inputs=[],
232
  outputs=[text_input, audio_output]
233
  )
 
234
 
235
  # Launch the app
236
  if __name__ == "__main__":
237
- demo.queue().launch(share=False, ssr_mode=False)
 
3
  import torch
4
  import gradio as gr
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
+ # Removed snapshot_download as from_pretrained handles caching
7
  from dotenv import load_dotenv
8
+ import gc # Import garbage collector for memory management
9
+
10
  load_dotenv()
11
 
12
# --- Global Variables ---
# State of the currently loaded Orpheus checkpoint. load_model_and_tokenizer()
# rebinds these three names; each stays None until a load has succeeded.
current_model = None
current_tokenizer = None
current_model_name = None

# Checkpoints offered in the model dropdown, and the one loaded at startup.
model_choices = [
    "Mohaddz/orpheus-3b-0.1-ft-ar",
    "Mohaddz/orpheus-arabic-exp",
]
default_model_name = "Mohaddz/orpheus-3b-0.1-ft-ar"
# --- End Global Variables ---

# Choose the compute device once; bfloat16 is used only on CUDA, while the
# CPU path falls back to float32.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.bfloat16
else:
    device, dtype = "cpu", torch.float32

print("Loading SNAC model...")
try:
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
except Exception as e:
    # Leave the name bound to None so the rest of the module can still be
    # imported and can test for a missing vocoder.
    print(f"Error loading SNAC model: {e}")
    snac_model = None
else:
    print("SNAC model loaded.")
32
+
33
+ # --- Model Loading Function ---
34
def load_model_and_tokenizer(model_name_to_load, progress=gr.Progress(track_tqdm=True)):
    """Make ``model_name_to_load`` the active Orpheus model and tokenizer.

    The previously loaded model and tokenizer are released *before* the new
    checkpoint is loaded. If the new load fails, no model is left loaded and
    ``current_model``, ``current_tokenizer`` and ``current_model_name`` are
    all reset to ``None``.

    Args:
        model_name_to_load: Hub repository id passed to ``from_pretrained``.
        progress: Gradio progress tracker; it is not called directly here and
            is only declared with ``track_tqdm=True``.

    Returns:
        A human-readable status string describing the outcome. Failures are
        reported through this string and ``gr.Warning``; nothing is raised.
    """
    # Only the names that are rebound need a ``global`` declaration;
    # ``device`` and ``dtype`` are merely read.
    global current_model, current_tokenizer, current_model_name

    if model_name_to_load == current_model_name and current_model is not None:
        print(f"Model {model_name_to_load} is already loaded.")
        gr.Info(f"Model {model_name_to_load} is already loaded.")
        return f"Model {model_name_to_load} already loaded."

    print("Unloading previous model if exists...")
    # Drop our references first, then collect and empty the CUDA cache, so the
    # old weights can be freed before the new ones are allocated.
    current_model = None
    current_tokenizer = None
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    print(f"Loading Orpheus model: {model_name_to_load}...")
    try:
        # from_pretrained handles both download and local caching.
        new_model = AutoModelForCausalLM.from_pretrained(model_name_to_load, torch_dtype=dtype)
        new_model.to(device)
        new_tokenizer = AutoTokenizer.from_pretrained(model_name_to_load)
    except Exception as e:
        import traceback

        print(f"Error loading model {model_name_to_load}: {e}")
        # Print the full traceback as generate_speech() does; the one-line
        # message alone rarely identifies the cause of a failed load.
        traceback.print_exc()
        # Reset all three globals so the state never names a model that is
        # not actually loaded.
        current_model = None
        current_tokenizer = None
        current_model_name = None
        gr.Warning(f"Failed to load model {model_name_to_load}. Please try again or select another model.")
        return f"Error loading {model_name_to_load}."

    # Publish the new state only after both model and tokenizer loaded.
    current_model = new_model
    current_tokenizer = new_tokenizer
    current_model_name = model_name_to_load

    print(f"Orpheus model {current_model_name} loaded successfully to {device}")
    gr.Info(f"Model {current_model_name} loaded.")
    return f"Model {current_model_name} loaded."
78
+ # --- End Model Loading Function ---
79
 
 
 
 
 
80
 
81
+ # Process text prompt (Uses global tokenizer now)
82
def process_prompt(prompt, voice, device):
    """Tokenize ``"{voice}: {prompt}"`` and frame it with the special tokens.

    Uses the module-level ``current_tokenizer``.

    Args:
        prompt: Text to be spoken.
        voice: Voice name placed in front of the text.
        device: Device the returned tensors are moved to.

    Returns:
        ``(input_ids, attention_mask)``, both on ``device``. The mask is all
        ones because a single, unpadded sequence is built.

    Raises:
        ValueError: If no tokenizer is currently loaded.
    """
    tokenizer = current_tokenizer
    if tokenizer is None:
        raise ValueError("Tokenizer not loaded.")

    text_ids = tokenizer(f"{voice}: {prompt}", return_tensors="pt").input_ids

    # 128259 = start of human; 128009 = end of text; 128260 = end of human.
    start_of_human = torch.tensor([[128259]], dtype=torch.int64)
    end_markers = torch.tensor([[128009, 128260]], dtype=torch.int64)

    # Final layout: SOH, tokenized text, EOT, EOH.
    framed_ids = torch.cat((start_of_human, text_ids, end_markers), dim=1)
    mask = torch.ones_like(framed_ids)

    return framed_ids.to(device), mask.to(device)
96
 
97
+ # Parse generated token IDs into a list of audio codes (returns [] when nothing was parsed)
98
  def parse_output(generated_ids):
99
  token_to_find = 128257
100
  token_to_remove = 128258
101
+
102
  token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
103
 
104
  if len(token_indices[1]) > 0:
 
117
  row_length = row.size(0)
118
  new_length = (row_length // 7) * 7
119
  trimmed_row = row[:new_length]
120
+ trimmed_row = [t - 128266 for t in trimmed_row] # Adjust based on actual token IDs if needed
121
  code_lists.append(trimmed_row)
 
 
122
 
123
+ return code_lists[0] if code_lists else [] # Handle empty case
124
+
125
+ # Redistribute codes into the three SNAC layers and decode them to audio (returns None when there is nothing to decode)
126
+ def redistribute_codes(code_list, snac_model_instance):
127
+ if not snac_model_instance or not code_list:
128
+ print("SNAC model not loaded or code list empty.")
129
+ return None
130
+ snac_device = next(snac_model_instance.parameters()).device
131
+
132
  layer_1 = []
133
  layer_2 = []
134
  layer_3 = []
135
+ num_frames = len(code_list) // 7 # Use integer division
136
+ for i in range(num_frames):
137
  layer_1.append(code_list[7*i])
138
  layer_2.append(code_list[7*i+1]-4096)
139
  layer_3.append(code_list[7*i+2]-(2*4096))
 
141
  layer_2.append(code_list[7*i+4]-(4*4096))
142
  layer_3.append(code_list[7*i+5]-(5*4096))
143
  layer_3.append(code_list[7*i+6]-(6*4096))
144
+
145
+ if not layer_1: # Check if any codes were processed
146
+ print("No valid frames found in code list.")
147
+ return None
148
+
149
  codes = [
150
+ torch.tensor(layer_1, device=snac_device).unsqueeze(0),
151
+ torch.tensor(layer_2, device=snac_device).unsqueeze(0),
152
+ torch.tensor(layer_3, device=snac_device).unsqueeze(0)
153
  ]
 
 
 
154
 
155
+ with torch.no_grad():
156
+ audio_hat = snac_model_instance.decode(codes)
157
+ return audio_hat.detach().squeeze().cpu().numpy()
158
+
159
+
160
+ # Main generation function (Uses global model now)
161
@spaces.GPU()
def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress(track_tqdm=True)):
    """Synthesize speech for ``text`` with the currently loaded Orpheus model.

    Args:
        text: Text to speak; ``None``, empty or whitespace-only input is
            rejected with an info message.
        voice: Voice name that ``process_prompt`` puts in front of the text.
        temperature: Sampling temperature, clamped to at least 0.01.
        top_p: Nucleus sampling threshold passed to ``generate``.
        repetition_penalty: Repetition penalty passed to ``generate``.
        max_new_tokens: Upper bound on the number of generated tokens.
        progress: Gradio progress tracker used to report pipeline stages.

    Returns:
        ``(24000, audio_samples)`` on success, where ``audio_samples`` is the
        value returned by ``redistribute_codes``; ``None`` when a precondition
        fails or no audio could be decoded.

    Raises:
        gr.Error: If any pipeline stage raises. ``gr.Error`` is an exception
            class, so it has to be raised for Gradio to show it to the user.
    """
    if current_model is None or current_tokenizer is None:
        gr.Warning("Orpheus model not loaded. Please select a model and wait for it to load.")
        return None
    if snac_model is None:
        gr.Warning("SNAC vocoder model failed to load. Cannot generate audio.")
        return None
    # ``not text`` also covers None, on which ``.strip()`` would fail.
    if not text or not text.strip():
        gr.Info("Please enter some text.")
        return None

    try:
        progress(0.1, "Processing text...")
        input_ids, attention_mask = process_prompt(text, voice, device)

        # Fall back to the EOS id when the tokenizer defines no pad token.
        pad_token_id = current_tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = current_tokenizer.eos_token_id

        progress(0.3, "Generating speech tokens...")
        with torch.no_grad():
            generated_ids = current_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=max(temperature, 0.01),  # never sample with a zero temperature
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=128258,  # NOTE(review): confirm this id for every selectable model
                pad_token_id=pad_token_id,
            )

        progress(0.6, "Processing speech tokens...")
        code_list = parse_output(generated_ids)

        progress(0.8, "Converting to audio...")
        audio_samples = redistribute_codes(code_list, snac_model)

        if audio_samples is None:
            gr.Warning("Failed to generate audio samples.")
            return None

        return (24000, audio_samples)  # sample rate and audio
    except Exception as e:
        import traceback

        print(f"Error generating speech: {e}")
        traceback.print_exc()  # full traceback for debugging
        # Merely constructing gr.Error shows nothing; it must be raised.
        raise gr.Error(f"An error occurred during generation: {e}") from e
212
 
213
# --- Load Default Model at Startup ---
# Load before the interface is built so that the widgets created further down
# can be initialised from the outcome of this load.
print("Loading default model...")
initial_status = load_model_and_tokenizer(default_model_name)
print(initial_status)
# --- End Load Default Model ---

# Voice names offered in the UI.
# NOTE(review): confirm that each voice exists in every fine-tuned model.
VOICES = ["tara", "dan", "josh", "emma"]

# Examples for the UI. Column order per row:
# text, voice, temperature, top_p, repetition_penalty, max_new_tokens.
examples = [
    ["السلام عليكم كيف حالكم اليوم؟", "tara", 0.6, 0.95, 1.1, 1200],
    ["أنا نموذج لتحويل النص إلى كلام يمكنه التحدث باللغة العربية.", "dan", 0.7, 0.95, 1.1, 1200],
]
232
 
233
  # Create Gradio interface
234
  with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
235
  gr.Markdown("""
236
+ # 🎵 Orpheus Text-to-Speech (Arabic Fine-tuned)
237
+ Enter your text below and hear it converted to natural-sounding speech.
238
+ Select the desired fine-tuned model below.
239
+ """)
240
+
241
+ with gr.Row():
242
+ # Model Selection Dropdown
243
+ model_selector = gr.Dropdown(
244
+ choices=model_choices,
245
+ value=current_model_name, # Default to the loaded model
246
+ label="Select Fine-Tuned Model",
247
+ interactive=True
248
+ )
249
+ # Status Textbox (Optional)
250
+ status_display = gr.Textbox(label="Model Status", value=initial_status, interactive=False)
251
+
252
+
253
  with gr.Row():
254
  with gr.Column(scale=3):
255
  text_input = gr.Textbox(
256
+ label="Text to speak (النص)",
257
+ placeholder="أدخل النص هنا...",
258
+ lines=5,
259
+ text_align="right" # Align text right for Arabic
260
  )
261
  voice = gr.Dropdown(
262
+ choices=VOICES,
263
+ value="tara", # Default voice
264
+ label="Voice (الصوت)"
265
  )
266
+
267
+ with gr.Accordion("Advanced Settings (إعدادات متقدمة)", open=False):
268
  temperature = gr.Slider(
269
  minimum=0.1, maximum=1.5, value=0.6, step=0.05,
270
+ label="Temperature (درجة الحرارة)",
271
  info="Higher values (0.7-1.0) create more expressive but less stable speech"
272
  )
273
  top_p = gr.Slider(
274
  minimum=0.1, maximum=1.0, value=0.95, step=0.05,
275
+ label="Top P",
276
  info="Nucleus sampling threshold"
277
  )
278
  repetition_penalty = gr.Slider(
279
  minimum=1.0, maximum=2.0, value=1.1, step=0.05,
280
+ label="Repetition Penalty (عقوبة التكرار)",
281
  info="Higher values discourage repetitive patterns"
282
  )
283
  max_new_tokens = gr.Slider(
284
  minimum=100, maximum=2000, value=1200, step=100,
285
+ label="Max Length (الطول الأقصى)",
286
  info="Maximum length of generated audio (in tokens)"
287
  )
288
+
289
  with gr.Row():
290
+ submit_btn = gr.Button("Generate Speech (توليد الكلام)", variant="primary")
291
+ clear_btn = gr.Button("Clear (مسح)")
292
+
293
  with gr.Column(scale=2):
294
+ audio_output = gr.Audio(label="Generated Speech (الكلام المولّد)", type="numpy")
295
+
296
  # Set up examples
297
  gr.Examples(
298
  examples=examples,
299
  inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
300
  outputs=audio_output,
301
+ fn=generate_speech, # Function to call for examples
302
+ cache_examples=False, # Disable caching if models change behavior
303
+ )
304
+
305
+ # --- Event Handlers ---
306
+ # Trigger model loading when dropdown changes
307
+ model_selector.change(
308
+ fn=load_model_and_tokenizer,
309
+ inputs=[model_selector],
310
+ outputs=[status_display] # Update status display
311
  )
312
+
313
+ # Generate speech button click
314
  submit_btn.click(
315
  fn=generate_speech,
316
  inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
317
  outputs=audio_output
318
  )
319
+
320
+ # Clear button click
321
  clear_btn.click(
322
  fn=lambda: (None, None),
323
  inputs=[],
324
  outputs=[text_input, audio_output]
325
  )
326
+ # --- End Event Handlers ---
327
 
328
  # Launch the app
329
if __name__ == "__main__":
    # Enable the request queue, then serve without a public share link.
    demo.queue()
    demo.launch(share=False)