bluenevus commited on
Commit
e3bea0f
·
verified ·
1 Parent(s): 33f9554

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -45
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import gradio as gr
2
  import google.generativeai as genai
3
  import numpy as np
4
  import re
@@ -14,62 +14,40 @@ from dotenv import load_dotenv
14
 
15
  load_dotenv()
16
 
17
- # Set up logging
18
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
19
  logger = logging.getLogger(__name__)
20
 
21
- # Suppress specific warnings
22
  warnings.filterwarnings("ignore", category=UserWarning)
23
  warnings.filterwarnings("ignore", category=RuntimeWarning)
24
 
25
- def get_device():
26
- return "cuda" if torch.cuda.is_available() else "cpu"
27
-
28
- device = get_device()
29
  logger.info(f"Using device: {device}")
30
 
31
  model = None
32
  tokenizer = None
33
  snac_model = None
34
 
35
- @spaces.GPU()
36
  def load_model():
37
  global model, tokenizer, snac_model
38
-
39
- logger.info("Loading SNAC model...")
40
- snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
41
- snac_model = snac_model.to(device)
42
-
43
- logger.info("Loading Orpheus model...")
44
- model_name = "canopylabs/orpheus-3b-0.1-ft"
45
 
46
- hf_token = os.environ.get("HUGGINGFACE_TOKEN")
47
- if not hf_token:
48
- raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
49
 
50
- try:
51
  login(token=hf_token)
52
 
53
  snapshot_download(
54
  repo_id=model_name,
55
  use_auth_token=hf_token,
56
- allow_patterns=[
57
- "config.json",
58
- "*.safetensors",
59
- "model.safetensors.index.json",
60
- ],
61
- ignore_patterns=[
62
- "optimizer.pt",
63
- "pytorch_model.bin",
64
- "training_args.bin",
65
- "scheduler.pt",
66
- "tokenizer.json",
67
- "tokenizer_config.json",
68
- "special_tokens_map.json",
69
- "vocab.json",
70
- "merges.txt",
71
- "tokenizer.*"
72
- ]
73
  )
74
 
75
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
@@ -80,7 +58,100 @@ def load_model():
80
  logger.error(f"Error loading model: {str(e)}")
81
  raise
82
 
83
- @spaces.GPU()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  def text_to_speech(text, voice, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200):
85
  global model, tokenizer, snac_model
86
  if model is None or tokenizer is None or snac_model is None:
@@ -108,12 +179,11 @@ def text_to_speech(text, voice, temperature=0.6, top_p=0.95, repetition_penalty=
108
  code_list = parse_output(generated_ids)
109
  audio_samples = redistribute_codes(code_list, snac_model)
110
 
111
- return (24000, audio_samples) # Return sample rate and audio
112
  except Exception as e:
113
  logger.error(f"Error in text_to_speech: {str(e)}")
114
  raise
115
 
116
- @spaces.GPU()
117
  def render_podcast(api_key, script, voice1, voice2, num_hosts):
118
  try:
119
  lines = [line for line in script.split('\n') if line.strip()]
@@ -122,8 +192,10 @@ def render_podcast(api_key, script, voice1, voice2, num_hosts):
122
  for i, line in enumerate(lines):
123
  voice = voice1 if num_hosts == 1 or i % 2 == 0 else voice2
124
  try:
125
- sample_rate, audio = text_to_speech(line, voice)
126
- audio_segments.append(audio)
 
 
127
  except Exception as e:
128
  logger.error(f"Error processing audio segment: {str(e)}")
129
 
@@ -132,8 +204,6 @@ def render_podcast(api_key, script, voice1, voice2, num_hosts):
132
  return (24000, np.zeros(24000, dtype=np.float32))
133
 
134
  podcast_audio = np.concatenate(audio_segments)
135
-
136
- # Ensure the audio is in the correct format for Gradio
137
  podcast_audio = np.clip(podcast_audio, -1, 1)
138
  podcast_audio = (podcast_audio * 32767).astype(np.int16)
139
 
@@ -142,11 +212,43 @@ def render_podcast(api_key, script, voice1, voice2, num_hosts):
142
  logger.error(f"Error rendering podcast: {str(e)}")
143
  raise
144
 
145
- # ... (rest of the code remains the same)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  if __name__ == "__main__":
148
  try:
149
- load_model() # Load models at startup
150
  demo.launch()
151
  except Exception as e:
152
  logger.error(f"Error launching the application: {str(e)}")
 
1
+ ]import gradio as gr
2
  import google.generativeai as genai
3
  import numpy as np
4
  import re
 
14
 
15
  load_dotenv()
16
 
 
17
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
18
  logger = logging.getLogger(__name__)
19
 
 
20
  warnings.filterwarnings("ignore", category=UserWarning)
21
  warnings.filterwarnings("ignore", category=RuntimeWarning)
22
 
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
24
  logger.info(f"Using device: {device}")
25
 
26
  model = None
27
  tokenizer = None
28
  snac_model = None
29
 
 
30
  def load_model():
31
  global model, tokenizer, snac_model
32
+ try:
33
+ logger.info("Loading SNAC model...")
34
+ snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
35
+ snac_model = snac_model.to(device)
36
+
37
+ logger.info("Loading Orpheus model...")
38
+ model_name = "canopylabs/orpheus-3b-0.1-ft"
39
 
40
+ hf_token = os.environ.get("HUGGINGFACE_TOKEN")
41
+ if not hf_token:
42
+ raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
43
 
 
44
  login(token=hf_token)
45
 
46
  snapshot_download(
47
  repo_id=model_name,
48
  use_auth_token=hf_token,
49
+ allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
50
+ ignore_patterns=["optimizer.pt", "pytorch_model.bin", "training_args.bin", "scheduler.pt", "tokenizer.*"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  )
52
 
53
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
 
58
  logger.error(f"Error loading model: {str(e)}")
59
  raise
60
 
61
+ def generate_podcast_script(api_key, content, uploaded_file, duration, num_hosts):
62
+ try:
63
+ genai.configure(api_key=api_key)
64
+ model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
65
+
66
+ combined_content = content or ""
67
+ if uploaded_file:
68
+ file_content = uploaded_file.read().decode('utf-8')
69
+ combined_content += "\n" + file_content if combined_content else file_content
70
+
71
+ prompt = f"""
72
+ Create a podcast script for {'one person' if num_hosts == 1 else 'two people'} discussing:
73
+ {combined_content}
74
+
75
+ Duration: {duration}. Include natural speech, humor, and occasional off-topic thoughts.
76
+ Use speech fillers like um, ah. Vary emotional tone.
77
+
78
+ Format: {'Monologue' if num_hosts == 1 else 'Alternating dialogue'} without speaker labels.
79
+ Separate {'paragraphs' if num_hosts == 1 else 'lines'} with blank lines.
80
+
81
+ Use emotion tags in angle brackets: <laugh>, <sigh>, <chuckle>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>.
82
+
83
+ Example: "I can't believe I stayed up all night <yawn> only to find out the meeting was canceled <groan>."
84
+
85
+ Ensure content flows naturally and stays on topic. Match the script length to {duration}.
86
+ """
87
+
88
+ response = model.generate_content(prompt)
89
+ return re.sub(r'[^a-zA-Z0-9\s.,?!<>]', '', response.text)
90
+ except Exception as e:
91
+ logger.error(f"Error generating podcast script: {str(e)}")
92
+ raise
93
+
94
+ def process_prompt(prompt, voice, tokenizer, device):
95
+ prompt = f"{voice}: {prompt}"
96
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
97
+
98
+ start_token = torch.tensor([[128259]], dtype=torch.int64)
99
+ end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
100
+
101
+ modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
102
+ attention_mask = torch.ones_like(modified_input_ids)
103
+
104
+ return modified_input_ids.to(device), attention_mask.to(device)
105
+
106
+ def parse_output(generated_ids):
107
+ token_to_find = 128257
108
+ token_to_remove = 128258
109
+
110
+ token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
111
+
112
+ if len(token_indices[1]) > 0:
113
+ last_occurrence_idx = token_indices[1][-1].item()
114
+ cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
115
+ else:
116
+ cropped_tensor = generated_ids
117
+
118
+ processed_rows = []
119
+ for row in cropped_tensor:
120
+ masked_row = row[row != token_to_remove]
121
+ processed_rows.append(masked_row)
122
+
123
+ code_lists = []
124
+ for row in processed_rows:
125
+ row_length = row.size(0)
126
+ new_length = (row_length // 7) * 7
127
+ trimmed_row = row[:new_length]
128
+ trimmed_row = [t - 128266 for t in trimmed_row]
129
+ code_lists.append(trimmed_row)
130
+
131
+ return code_lists[0]
132
+
133
+ def redistribute_codes(code_list, snac_model):
134
+ device = next(snac_model.parameters()).device
135
+
136
+ layer_1, layer_2, layer_3 = [], [], []
137
+ for i in range((len(code_list)+1)//7):
138
+ layer_1.append(code_list[7*i])
139
+ layer_2.append(code_list[7*i+1]-4096)
140
+ layer_3.append(code_list[7*i+2]-(2*4096))
141
+ layer_3.append(code_list[7*i+3]-(3*4096))
142
+ layer_2.append(code_list[7*i+4]-(4*4096))
143
+ layer_3.append(code_list[7*i+5]-(5*4096))
144
+ layer_3.append(code_list[7*i+6]-(6*4096))
145
+
146
+ codes = [
147
+ torch.tensor(layer_1, device=device).unsqueeze(0),
148
+ torch.tensor(layer_2, device=device).unsqueeze(0),
149
+ torch.tensor(layer_3, device=device).unsqueeze(0)
150
+ ]
151
+
152
+ audio_hat = snac_model.decode(codes)
153
+ return audio_hat.detach().squeeze().cpu().numpy()
154
+
155
  def text_to_speech(text, voice, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200):
156
  global model, tokenizer, snac_model
157
  if model is None or tokenizer is None or snac_model is None:
 
179
  code_list = parse_output(generated_ids)
180
  audio_samples = redistribute_codes(code_list, snac_model)
181
 
182
+ return (24000, audio_samples)
183
  except Exception as e:
184
  logger.error(f"Error in text_to_speech: {str(e)}")
185
  raise
186
 
 
187
  def render_podcast(api_key, script, voice1, voice2, num_hosts):
188
  try:
189
  lines = [line for line in script.split('\n') if line.strip()]
 
192
  for i, line in enumerate(lines):
193
  voice = voice1 if num_hosts == 1 or i % 2 == 0 else voice2
194
  try:
195
+ result = text_to_speech(line, voice)
196
+ if result is not None:
197
+ sample_rate, audio = result
198
+ audio_segments.append(audio)
199
  except Exception as e:
200
  logger.error(f"Error processing audio segment: {str(e)}")
201
 
 
204
  return (24000, np.zeros(24000, dtype=np.float32))
205
 
206
  podcast_audio = np.concatenate(audio_segments)
 
 
207
  podcast_audio = np.clip(podcast_audio, -1, 1)
208
  podcast_audio = (podcast_audio * 32767).astype(np.int16)
209
 
 
212
  logger.error(f"Error rendering podcast: {str(e)}")
213
  raise
214
 
215
+ with gr.Blocks() as demo:
216
+ gr.Markdown("# AI Podcast Generator")
217
+
218
+ api_key_input = gr.Textbox(label="Enter your Gemini API Key", type="password")
219
+
220
+ with gr.Row():
221
+ content_input = gr.Textbox(label="Paste your content (optional)")
222
+ document_upload = gr.File(label="Upload Document (optional)")
223
+
224
+ duration = gr.Radio(["1-5 min", "5-10 min", "10-15 min"], label="Estimated podcast duration")
225
+ num_hosts = gr.Radio([1, 2], label="Number of podcast hosts", value=2)
226
+
227
+ voice_options = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
228
+ voice1_select = gr.Dropdown(label="Select Voice 1", choices=voice_options, value="tara")
229
+ voice2_select = gr.Dropdown(label="Select Voice 2", choices=voice_options, value="leo")
230
+
231
+ generate_btn = gr.Button("Generate Script")
232
+ script_output = gr.Textbox(label="Generated Script", lines=10)
233
+
234
+ render_btn = gr.Button("Render Podcast")
235
+ audio_output = gr.Audio(label="Generated Podcast")
236
+
237
+ generate_btn.click(generate_podcast_script,
238
+ inputs=[api_key_input, content_input, document_upload, duration, num_hosts],
239
+ outputs=script_output)
240
+
241
+ render_btn.click(render_podcast,
242
+ inputs=[api_key_input, script_output, voice1_select, voice2_select, num_hosts],
243
+ outputs=audio_output)
244
+
245
+ num_hosts.change(lambda x: gr.update(visible=x == 2),
246
+ inputs=[num_hosts],
247
+ outputs=[voice2_select])
248
 
249
  if __name__ == "__main__":
250
  try:
251
+ load_model()
252
  demo.launch()
253
  except Exception as e:
254
  logger.error(f"Error launching the application: {str(e)}")