bluenevus committed
Commit db5919c · verified · 1 Parent(s): c10cafd

Update app.py

Files changed (1): app.py +101 -35
app.py CHANGED
@@ -3,14 +3,16 @@ import google.generativeai as genai
 import numpy as np
 import re
 import torch
-import torchaudio
-import torchaudio.functional as F
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from huggingface_hub import snapshot_download, login
 import logging
 import os
 import spaces
 import warnings
+from snac import SNAC
+from dotenv import load_dotenv
+
+load_dotenv()
 
 # Set up logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -21,19 +23,22 @@ warnings.filterwarnings("ignore", category=UserWarning)
 warnings.filterwarnings("ignore", category=RuntimeWarning)
 
 def get_device():
-    if torch.cuda.is_available():
-        return torch.device("cuda")
-    return torch.device("cpu")
+    return "cuda" if torch.cuda.is_available() else "cpu"
 
 device = get_device()
 logger.info(f"Using device: {device}")
 
 model = None
 tokenizer = None
+snac_model = None
 
 @spaces.GPU()
 def load_model():
-    global model, tokenizer
+    global model, tokenizer, snac_model
+
+    logger.info("Loading SNAC model...")
+    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
+    snac_model = snac_model.to(device)
 
     logger.info("Loading Orpheus model...")
     model_name = "canopylabs/orpheus-3b-0.1-ft"
@@ -67,7 +72,7 @@ def load_model():
         ]
     )
 
-    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32 if device.type == 'cpu' else torch.bfloat16)
+    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
     model.to(device)
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     logger.info(f"Orpheus model and tokenizer loaded to {device}")
@@ -108,42 +113,102 @@ def generate_podcast_script(api_key, content, uploaded_file, duration, num_hosts
         logger.error(f"Error generating podcast script: {str(e)}")
         raise
 
+def process_prompt(prompt, voice, tokenizer, device):
+    prompt = f"{voice}: {prompt}"
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+    start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
+    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
+
+    modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)  # SOH SOT Text EOT EOH
+
+    # No padding needed for single input
+    attention_mask = torch.ones_like(modified_input_ids)
+
+    return modified_input_ids.to(device), attention_mask.to(device)
+
+def parse_output(generated_ids):
+    token_to_find = 128257
+    token_to_remove = 128258
+
+    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
+
+    if len(token_indices[1]) > 0:
+        last_occurrence_idx = token_indices[1][-1].item()
+        cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
+    else:
+        cropped_tensor = generated_ids
+
+    processed_rows = []
+    for row in cropped_tensor:
+        masked_row = row[row != token_to_remove]
+        processed_rows.append(masked_row)
+
+    code_lists = []
+    for row in processed_rows:
+        row_length = row.size(0)
+        new_length = (row_length // 7) * 7
+        trimmed_row = row[:new_length]
+        trimmed_row = [t - 128266 for t in trimmed_row]
+        code_lists.append(trimmed_row)
+
+    return code_lists[0]  # Return just the first one for single sample
+
+def redistribute_codes(code_list, snac_model):
+    device = next(snac_model.parameters()).device  # Get the device of SNAC model
+
+    layer_1 = []
+    layer_2 = []
+    layer_3 = []
+    for i in range((len(code_list)+1)//7):
+        layer_1.append(code_list[7*i])
+        layer_2.append(code_list[7*i+1]-4096)
+        layer_3.append(code_list[7*i+2]-(2*4096))
+        layer_3.append(code_list[7*i+3]-(3*4096))
+        layer_2.append(code_list[7*i+4]-(4*4096))
+        layer_3.append(code_list[7*i+5]-(5*4096))
+        layer_3.append(code_list[7*i+6]-(6*4096))
+
+    # Move tensors to the same device as the SNAC model
+    codes = [
+        torch.tensor(layer_1, device=device).unsqueeze(0),
+        torch.tensor(layer_2, device=device).unsqueeze(0),
+        torch.tensor(layer_3, device=device).unsqueeze(0)
+    ]
+
+    audio_hat = snac_model.decode(codes)
+    return audio_hat.detach().squeeze().cpu().numpy()  # Always return CPU numpy array
+
 @spaces.GPU()
-def text_to_speech(text, voice):
-    global model, tokenizer
+def text_to_speech(text, voice, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200):
+    global model, tokenizer, snac_model
+    if not text.strip():
+        return None
+
     try:
-        if model is None or tokenizer is None:
-            load_model()
+        input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
 
-        # Remove emotion tags for TTS processing
-        clean_text = re.sub(r'<[^>]+>', '', text)
-
-        inputs = tokenizer(clean_text, return_tensors="pt").to(device)
         with torch.no_grad():
-            output = model.generate(**inputs, max_new_tokens=256)
-
-        # Convert output tensor to mel spectrogram
-        mel = output[0].cpu()
-
-        # Reshape mel to match expected dimensions
-        n_mels = 80  # Typical number of mel bands
-        time_dim = mel.shape[0]
-        mel_reshaped = mel.view(n_mels, -1)
-
-        # Normalize the mel spectrogram
-        mel_reshaped = (mel_reshaped - mel_reshaped.min()) / (mel_reshaped.max() - mel_reshaped.min())
-
-        # Convert mel spectrogram to audio using torchaudio
-        audio = F.griffinlim(mel_reshaped.unsqueeze(0), n_iter=10, n_fft=2048, hop_length=512, win_length=2048)
-
-        # Convert to numpy array and ensure it's in the correct format
-        audio_np = audio.squeeze().numpy()
-        audio_np = np.clip(audio_np, -1, 1)
-
-        return (24000, audio_np.astype(np.float32))  # Assuming 24kHz sample rate
+            generated_ids = model.generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                max_new_tokens=max_new_tokens,
+                do_sample=True,
+                temperature=temperature,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                num_return_sequences=1,
+                eos_token_id=128258,
+            )
+
+        code_list = parse_output(generated_ids)
+        audio_samples = redistribute_codes(code_list, snac_model)
+
+        return (24000, audio_samples)  # Return sample rate and audio
     except Exception as e:
        logger.error(f"Error in text_to_speech: {str(e)}")
        raise
+
 @spaces.GPU()
 def render_podcast(api_key, script, voice1, voice2, num_hosts):
     try:
@@ -153,7 +218,7 @@ def render_podcast(api_key, script, voice1, voice2, num_hosts):
         for i, line in enumerate(lines):
             voice = voice1 if num_hosts == 1 or i % 2 == 0 else voice2
             try:
-                _, audio = text_to_speech(line, voice)
+                sample_rate, audio = text_to_speech(line, voice)
                 audio_segments.append(audio)
             except Exception as e:
                 logger.error(f"Error processing audio segment: {str(e)}")
@@ -173,6 +238,7 @@ def render_podcast(api_key, script, voice1, voice2, num_hosts):
         logger.error(f"Error rendering podcast: {str(e)}")
         raise
 
+# Gradio Interface
 with gr.Blocks() as demo:
     gr.Markdown("# AI Podcast Generator")
 
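Note on the decoding path added in this commit: parse_output crops the generation to everything after the last 128257 marker token, strips the 128258 end/padding tokens, trims to a whole number of 7-token frames, and shifts the IDs down by 128266 to recover raw SNAC codes. A tiny worked sketch (the token values after the marker are invented for illustration; parse_output is the helper from the diff above):

import torch

generated = torch.tensor([[55, 128257,
                           128266 + 10,    128266 + 4116,  128258,
                           128266 + 8222,  128266 + 12328, 128266 + 16434,
                           128266 + 20540, 128266 + 24646, 128258]])

# Crop after the last 128257 -> 9 tokens; drop the two 128258s -> 7 tokens;
# 7 is already a multiple of 7, so nothing more is trimmed.
code_list = parse_output(generated)
# After the 128266 shift the codes are [10, 4116, 8222, 12328, 16434, 20540, 24646]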
 
 
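redistribute_codes then de-interleaves each 7-code frame into the three SNAC codebook layers (1 coarse, 2 medium, and 4 fine codes per frame), removing a per-slot offset that is a multiple of 4096. Applied by hand to the frame from the sketch above:

frame = [10, 4116, 8222, 12328, 16434, 20540, 24646]

layer_1 = [frame[0]]                                  # coarse: 1 code/frame -> [10]
layer_2 = [frame[1] - 4096, frame[4] - 4 * 4096]      # medium: 2 codes/frame -> [20, 50]
layer_3 = [frame[2] - 2 * 4096, frame[3] - 3 * 4096,  # fine: 4 codes/frame -> [30, 40, 60, 70]
           frame[5] - 5 * 4096, frame[6] - 6 * 4096]

snac_model.decode then turns the three stacked layers back into a 24 kHz waveform.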
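A hedged usage sketch for the new text_to_speech: it returns a (sample_rate, numpy_array) tuple, the format Gradio's gr.Audio component accepts. Writing the result to disk with the soundfile package is an assumption for this sketch (the commit itself never saves files, and the voice name is a guess):

import soundfile as sf  # assumption: not a dependency of this commit

load_model()  # must run first; text_to_speech no longer loads on demand
result = text_to_speech("Welcome to the show!", voice="tara")  # voice name illustrative
if result is not None:
    sample_rate, audio = result                  # 24000, float waveform
    sf.write("segment.wav", audio, sample_rate)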
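render_podcast collects one waveform per script line; since every segment now comes back at the same 24 kHz rate, joining them reduces to a single np.concatenate. A minimal sketch of that step with an optional pause between lines (the helper name and pause length are illustrative, not part of the commit):

import numpy as np

def join_segments(segments, pause_ms=300, sample_rate=24000):
    # Insert a short silence between consecutive lines, then concatenate.
    if not segments:
        return np.zeros(0, dtype=np.float32)
    gap = np.zeros(int(sample_rate * pause_ms / 1000), dtype=np.float32)
    parts = []
    for seg in segments[:-1]:
        parts.extend([seg, gap])
    parts.append(segments[-1])
    return np.concatenate(parts)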