Rausda6 commited on
Commit
132e1a9
·
verified ·
1 Parent(s): e8c85bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -370
app.py CHANGED
@@ -15,15 +15,14 @@ from typing import List
15
 
16
  from PyPDF2 import PdfReader
17
 
18
-
19
  # Define model name clearly
20
- MODEL_NAME = "unsloth/gemma-3-1b-pt" # HuggingFaceH4/zephyr-7b-alpha
21
 
22
  # Device setup
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
  print(f"Using device: {device}")
25
 
26
- # Load model and tokenizer (explicit evaluation mode)
27
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
28
  model = AutoModelForCausalLM.from_pretrained(
29
  MODEL_NAME,
@@ -32,7 +31,7 @@ model = AutoModelForCausalLM.from_pretrained(
32
 
33
  # Constants
34
  MAX_FILE_SIZE_MB = 20
35
- MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
36
 
37
  class PodcastGenerator:
38
  def __init__(self):
@@ -40,63 +39,19 @@ class PodcastGenerator:
40
 
41
  async def generate_script(self, prompt: str, language: str, api_key: str, file_obj=None, progress=None):
42
  example = """
43
- {
44
- "topic": "AGI",
45
- "podcast": [
46
- {
47
- "speaker": 2,
48
- "line": "So, AGI, huh? Seems like everyone's talking about it these days."
49
- },
50
- {
51
- "speaker": 1,
52
- "line": "Yeah, it's definitely having a moment, isn't it?"
53
- },
54
- {
55
- "speaker": 2,
56
- "line": "It is and for good reason, right? I mean, you've been digging into this stuff, listening to the podcasts and everything. What really stood out to you? What got you hooked?"
57
- },
58
- {
59
- "speaker": 1,
60
- "line": "It's easy to get lost in the noise, for sure."
61
- },
62
- {
63
- "speaker": 2,
64
- "line": "Exactly. So how about we try to cut through some of that, shall we?"
65
- },
66
- {
67
- "speaker": 1,
68
- "line": "Sounds like a plan."
69
- },
70
- {
71
- "speaker": 2,
72
- "line": "It certainly is and on that note, we'll wrap up this deep dive. Thanks for listening, everyone."
73
- },
74
- {
75
- "speaker": 1,
76
- "line": "Peace."
77
- }
78
- ]
79
- }
80
  """
81
-
82
  if language == "Auto Detect":
83
  language_instruction = "- The podcast MUST be in the same language as the user input."
84
  else:
85
  language_instruction = f"- The podcast MUST be in {language} language"
86
 
87
  system_prompt = f"""
88
- You are a professional podcast generator. Your task is to generate a professional podcast script based on the user input.
89
  {language_instruction}
90
- - The podcast should have 2 speakers.
91
- - The podcast should be long.
92
- - Do not use names for the speakers.
93
- - The podcast should be interesting, lively, and engaging, and hook the listener from the start.
94
- - The input text might be disorganized or unformatted, originating from sources like PDFs or text files. Ignore any formatting inconsistencies or irrelevant details; your task is to distill the essential points, identify key definitions, and highlight intriguing facts that would be suitable for discussion in a podcast.
95
- - The script must be in JSON format.
96
  Follow this example structure:
97
  {example}
98
  """
99
- # Build the user prompt
100
  if prompt and file_obj:
101
  user_prompt = f"Please generate a podcast script based on the uploaded file following user input:\n{prompt}"
102
  elif prompt:
@@ -104,344 +59,68 @@ Follow this example structure:
104
  else:
105
  user_prompt = "Please generate a podcast script based on the uploaded file."
106
 
107
- # If a file is provided, extract its text and append
108
  if file_obj:
109
- # enforce size limit
110
  file_size = getattr(file_obj, 'size', os.path.getsize(file_obj.name))
111
  if file_size > MAX_FILE_SIZE_BYTES:
112
- raise Exception(f"File size exceeds the {MAX_FILE_SIZE_MB}MB limit. Please upload a smaller file.")
113
-
114
- # extract text based on mime
115
  ext = os.path.splitext(file_obj.name)[1].lower()
116
  if ext == '.pdf':
117
  reader = PdfReader(file_obj)
118
  text = "\n\n".join(page.extract_text() or '' for page in reader.pages)
119
  else:
120
- # txt or other
121
- if hasattr(file_obj, 'read'):
122
- raw = file_obj.read()
123
- else:
124
- raw = await aiofiles.open(file_obj.name, 'rb').read()
125
  text = raw.decode(errors='ignore')
126
-
127
  user_prompt += f"\n\n―― FILE CONTENT ――\n{text}"
128
 
129
- # Combine system and user prompts
130
  prompt_text = system_prompt + "\n" + user_prompt
131
 
132
  try:
133
- if progress:
134
- progress(0.3, "Generating podcast script...")
135
-
136
- def hf_generate(prompt_text):
137
- inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
138
- outputs = model.generate(
139
- **inputs,
140
- max_new_tokens=1024,
141
- do_sample=True,
142
- temperature=1.0
143
- )
144
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
145
-
146
- generated_text = await asyncio.wait_for(
147
- asyncio.to_thread(hf_generate, prompt_text),
148
- timeout=60
149
- )
150
-
151
  except asyncio.TimeoutError:
152
- raise Exception("The script generation request timed out. Please try again later.")
153
  except Exception as e:
154
- raise Exception(f"Failed to generate podcast script: {e}")
155
-
156
- if progress:
157
- progress(0.4, "Script generated successfully!")
158
-
159
  return json.loads(generated_text)
160
 
161
- # ... rest of class unchanged ...
162
-
163
-
164
- # ... rest of class unchanged ...
165
 
166
-
167
- # ... rest of class unchanged ...
168
-
169
-
170
-
171
- async def _read_file_bytes(self, file_obj) -> bytes:
172
- """Read file bytes from a file object"""
173
- # Check file size before reading
174
- if hasattr(file_obj, 'size'):
175
- file_size = file_obj.size
176
- else:
177
- file_size = os.path.getsize(file_obj.name)
178
-
179
- if file_size > MAX_FILE_SIZE_BYTES:
180
- raise Exception(f"File size exceeds the {MAX_FILE_SIZE_MB}MB limit. Please upload a smaller file.")
181
-
182
- if hasattr(file_obj, 'read'):
183
- return file_obj.read()
184
- else:
185
- async with aiofiles.open(file_obj.name, 'rb') as f:
186
- return await f.read()
187
-
188
- def _get_mime_type(self, filename: str) -> str:
189
- """Determine MIME type based on file extension"""
190
- ext = os.path.splitext(filename)[1].lower()
191
- if ext == '.pdf':
192
- return "application/pdf"
193
- elif ext == '.txt':
194
- return "text/plain"
195
- else:
196
- # Fallback to the default mime type detector
197
- mime_type, _ = mimetypes.guess_type(filename)
198
- return mime_type or "application/octet-stream"
199
-
200
- async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
201
- voice = speaker1 if speaker == 1 else speaker2
202
- speech = edge_tts.Communicate(text, voice)
203
-
204
- temp_filename = f"temp_{uuid.uuid4()}.wav"
205
- try:
206
- # Add timeout to TTS generation
207
- await asyncio.wait_for(speech.save(temp_filename), timeout=30) # 30 seconds timeout
208
- return temp_filename
209
- except asyncio.TimeoutError:
210
- if os.path.exists(temp_filename):
211
- os.remove(temp_filename)
212
- raise Exception("Text-to-speech generation timed out. Please try with a shorter text.")
213
- except Exception as e:
214
- if os.path.exists(temp_filename):
215
- os.remove(temp_filename)
216
- raise e
217
-
218
- async def combine_audio_files(self, audio_files: List[str], progress=None) -> str:
219
- if progress:
220
- progress(0.9, "Combining audio files...")
221
-
222
- combined_audio = AudioSegment.empty()
223
- for audio_file in audio_files:
224
- combined_audio += AudioSegment.from_file(audio_file)
225
- os.remove(audio_file) # Clean up temporary files
226
-
227
- output_filename = f"output_{uuid.uuid4()}.wav"
228
- combined_audio.export(output_filename, format="wav")
229
-
230
- if progress:
231
- progress(1.0, "Podcast generated successfully!")
232
-
233
- return output_filename
234
-
235
- async def generate_podcast(self, input_text: str, language: str, speaker1: str, speaker2: str, api_key: str, file_obj=None, progress=None) -> str:
236
- try:
237
- if progress:
238
- progress(0.1, "Starting podcast generation...")
239
-
240
- # Set overall timeout for the entire process
241
- return await asyncio.wait_for(
242
- self._generate_podcast_internal(input_text, language, speaker1, speaker2, api_key, file_obj, progress),
243
- timeout=600 # 10 minutes total timeout
244
- )
245
- except asyncio.TimeoutError:
246
- raise Exception("The podcast generation process timed out. Please try with shorter text or try again later.")
247
- except Exception as e:
248
- raise Exception(f"Error generating podcast: {str(e)}")
249
-
250
- async def _generate_podcast_internal(self, input_text: str, language: str, speaker1: str, speaker2: str, api_key: str, file_obj=None, progress=None) -> str:
251
- if progress:
252
- progress(0.2, "Generating podcast script...")
253
-
254
- podcast_json = await self.generate_script(input_text, language, api_key, file_obj, progress)
255
-
256
- if progress:
257
- progress(0.5, "Converting text to speech...")
258
-
259
- # Process TTS in batches for concurrent processing
260
- audio_files = []
261
- total_lines = len(podcast_json['podcast'])
262
-
263
- # Define batch size to control concurrency
264
- batch_size = 10 # Adjust based on system resources
265
-
266
- # Process in batches
267
- for batch_start in range(0, total_lines, batch_size):
268
- batch_end = min(batch_start + batch_size, total_lines)
269
- batch = podcast_json['podcast'][batch_start:batch_end]
270
-
271
- # Create tasks for concurrent processing
272
- tts_tasks = []
273
- for item in batch:
274
- tts_task = self.tts_generate(item['line'], item['speaker'], speaker1, speaker2)
275
- tts_tasks.append(tts_task)
276
-
277
- try:
278
- # Process batch concurrently
279
- batch_results = await asyncio.gather(*tts_tasks, return_exceptions=True)
280
-
281
- # Check for exceptions and handle results
282
- for i, result in enumerate(batch_results):
283
- if isinstance(result, Exception):
284
- # Clean up any files already created
285
- for file in audio_files:
286
- if os.path.exists(file):
287
- os.remove(file)
288
- raise Exception(f"Error generating speech: {str(result)}")
289
- else:
290
- audio_files.append(result)
291
-
292
- # Update progress
293
- if progress:
294
- current_progress = 0.5 + (0.4 * (batch_end / total_lines))
295
- progress(current_progress, f"Processed {batch_end}/{total_lines} speech segments...")
296
-
297
- except Exception as e:
298
- # Clean up any files already created
299
- for file in audio_files:
300
- if os.path.exists(file):
301
- os.remove(file)
302
- raise Exception(f"Error in batch TTS generation: {str(e)}")
303
-
304
- combined_audio = await self.combine_audio_files(audio_files, progress)
305
- return combined_audio
306
-
307
async def process_input(input_text: str, input_file, language: str, speaker1: str, speaker2: str, api_key: str = "", progress=None) -> str:
    """Resolve display voice names, run the generation pipeline, and translate
    failures into user-friendly error messages.

    Args:
        input_text: Topic or pasted text (may be empty if a file is given).
        input_file: Optional uploaded file object.
        language: Podcast language name, or "Auto Detect".
        speaker1, speaker2: Display names; mapped to edge-tts voice ids below.
        api_key: Optional user-supplied key; falls back to env / placeholder.
        progress: Optional callable(fraction, message) for UI updates.

    Returns:
        Path to the generated podcast audio file.
    """
    start_time = time.time()

    voice_names = {
        "Andrew - English (United States)": "en-US-AndrewMultilingualNeural",
        "Ava - English (United States)": "en-US-AvaMultilingualNeural",
        "Brian - English (United States)": "en-US-BrianMultilingualNeural",
        "Emma - English (United States)": "en-US-EmmaMultilingualNeural",
        "Florian - German (Germany)": "de-DE-FlorianMultilingualNeural",
        "Seraphina - German (Germany)": "de-DE-SeraphinaMultilingualNeural",
        "Remy - French (France)": "fr-FR-RemyMultilingualNeural",
        "Vivienne - French (France)": "fr-FR-VivienneMultilingualNeural"
    }

    speaker1 = voice_names[speaker1]
    speaker2 = voice_names[speaker2]

    try:
        if progress:
            progress(0.05, "Processing input...")

        if not api_key:
            # Fix: the original unconditionally hard-coded the placeholder
            # "saf", which made its follow-up "no API key" check unreachable
            # dead code. Prefer the environment variable (the original's
            # commented-out intent) and keep the placeholder as last resort.
            api_key = os.getenv("GENAI_API_KEY") or "saf"

        podcast_generator = PodcastGenerator()
        podcast = await podcast_generator.generate_podcast(input_text, language, speaker1, speaker2, api_key, input_file, progress)

        end_time = time.time()
        print(f"Total podcast generation time: {end_time - start_time:.2f} seconds")
        return podcast

    except Exception as e:
        # Ensure we show a user-friendly error.
        error_msg = str(e)
        if "rate limit" in error_msg.lower():
            raise Exception("Rate limit exceeded. Please try again later or use your own API key.")
        elif "timeout" in error_msg.lower():
            raise Exception("The request timed out. This could be due to server load or the length of your input. Please try again with shorter text.")
        else:
            raise Exception(f"Error: {error_msg}")
349
 
350
  # Gradio UI
351
def generate_podcast_gradio(input_text, input_file, language, speaker1, speaker2, api_key):
    """Synchronous Gradio wrapper: drive the async pipeline and surface any
    failure as a gr.Error so it renders in the UI."""
    file_obj = input_file if input_file is not None else None
    try:
        # process_input still accepts a progress callback, but Gradio's own
        # progress bar is used instead — pass a no-op.
        pipeline = process_input(
            input_text,
            file_obj,
            language,
            speaker1,
            speaker2,
            api_key,
            lambda *_: None
        )
        return asyncio.run(pipeline)
    except Exception as e:
        raise gr.Error(str(e))
369
-
370
-
371
def main():
    """Build the Gradio Blocks UI and launch the app."""
    # Both speaker dropdowns share one list of display names.
    voice_choices = [
        "Andrew - English (United States)",
        "Ava - English (United States)",
        "Brian - English (United States)",
        "Emma - English (United States)",
        "Florian - German (Germany)",
        "Seraphina - German (Germany)",
        "Remy - French (France)",
        "Vivienne - French (France)",
    ]
    language_choices = [
        "Auto Detect",
        "English",
        "German",
        "French",
        "Spanish",
        "Italian",
        "Dutch",
        "Portuguese",
        "Russian",
        "Chinese",
        "Japanese",
        "Korean",
        "Other",
    ]

    with gr.Blocks(title="PodcastGen 🎙️") as demo:
        gr.Markdown(
            """
            # PodcastGen 🎙️
            Generate a 2-speaker podcast from text or PDF!
            """
        )
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input Text", lines=10, placeholder="Enter podcast topic or paste text here...", elem_id="input_text")
                input_file = gr.File(label="Or Upload a PDF or TXT file", file_types=[".pdf", ".txt"])
            with gr.Column():
                language = gr.Dropdown(label="Podcast Language", choices=language_choices, value="Auto Detect")
                speaker1 = gr.Dropdown(label="Speaker 1 Voice", choices=voice_choices, value="Andrew - English (United States)")
                speaker2 = gr.Dropdown(label="Speaker 2 Voice", choices=voice_choices, value="Ava - English (United States)")
                api_key = gr.Textbox(label="Gemini API Key (Optional)", type="password", placeholder="Needed only if you're getting rate limited.")

        generate_btn = gr.Button("Generate Podcast 🎙️", variant="primary")
        output_audio = gr.Audio(label="Generated Podcast", type="filepath", format="wav", elem_id="output_audio")

        generate_btn.click(
            fn=generate_podcast_gradio,
            inputs=[input_text, input_file, language, speaker1, speaker2, api_key],
            outputs=output_audio,
            show_progress=True,
        )

    demo.queue()
    demo.launch(server_name="0.0.0.0", debug=True)

if __name__ == "__main__":
    main()
 
15
 
16
  from PyPDF2 import PdfReader
17
 
 
18
  # Define model name clearly
19
+ MODEL_NAME = "unsloth/gemma-3-1b-pt"
20
 
21
  # Device setup
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
  print(f"Using device: {device}")
24
 
25
+ # Load model and tokenizer
26
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
27
  model = AutoModelForCausalLM.from_pretrained(
28
  MODEL_NAME,
 
31
 
32
  # Constants
33
  MAX_FILE_SIZE_MB = 20
34
+ MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024
35
 
36
  class PodcastGenerator:
37
  def __init__(self):
 
39
 
40
  async def generate_script(self, prompt: str, language: str, api_key: str, file_obj=None, progress=None):
41
  example = """
42
+ {...}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  """
 
44
  if language == "Auto Detect":
45
  language_instruction = "- The podcast MUST be in the same language as the user input."
46
  else:
47
  language_instruction = f"- The podcast MUST be in {language} language"
48
 
49
  system_prompt = f"""
50
+ You are a professional podcast generator...
51
  {language_instruction}
 
 
 
 
 
 
52
  Follow this example structure:
53
  {example}
54
  """
 
55
  if prompt and file_obj:
56
  user_prompt = f"Please generate a podcast script based on the uploaded file following user input:\n{prompt}"
57
  elif prompt:
 
59
  else:
60
  user_prompt = "Please generate a podcast script based on the uploaded file."
61
 
 
62
  if file_obj:
 
63
  file_size = getattr(file_obj, 'size', os.path.getsize(file_obj.name))
64
  if file_size > MAX_FILE_SIZE_BYTES:
65
+ raise Exception("File size exceeds limit.")
 
 
66
  ext = os.path.splitext(file_obj.name)[1].lower()
67
  if ext == '.pdf':
68
  reader = PdfReader(file_obj)
69
  text = "\n\n".join(page.extract_text() or '' for page in reader.pages)
70
  else:
71
+ raw = file_obj.read() if hasattr(file_obj, 'read') else await aiofiles.open(file_obj.name, 'rb').read()
 
 
 
 
72
  text = raw.decode(errors='ignore')
 
73
  user_prompt += f"\n\n―― FILE CONTENT ――\n{text}"
74
 
 
75
  prompt_text = system_prompt + "\n" + user_prompt
76
 
77
  try:
78
+ if progress: progress(0.3, "Generating podcast script...")
79
+ def hf_generate(p):
80
+ inputs = tokenizer(p, return_tensors="pt").to(model.device)
81
+ outs = model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=1.0)
82
+ return tokenizer.decode(outs[0], skip_special_tokens=True)
83
+ generated_text = await asyncio.wait_for(asyncio.to_thread(hf_generate, prompt_text), timeout=60)
 
 
 
 
 
 
 
 
 
 
 
 
84
  except asyncio.TimeoutError:
85
+ raise Exception("Script generation timed out.")
86
  except Exception as e:
87
+ raise Exception(f"Failed to generate script: {e}")
88
+ if progress: progress(0.4, "Script generated successfully!")
 
 
 
89
  return json.loads(generated_text)
90
 
91
+ # ... TTS and combine_audio_files methods unchanged ...
 
 
 
92
 
93
+ async def process_input(input_text, input_file, language, speaker1, speaker2, api_key="", progress=None):
94
+ # Implementation unchanged
95
+ ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  # Gradio UI
98
+ with gr.Blocks(title="PodcastGen 🎙️") as demo:
99
+ gr.Markdown("""
100
+ # PodcastGen 🎙️
101
+ Generate a 2-speaker podcast from text or PDF!
102
+ """
103
+ )
104
+ with gr.Row():
105
+ with gr.Column():
106
+ input_text = gr.Textbox(...)
107
+ input_file = gr.File(...)
108
+ with gr.Column():
109
+ language = gr.Dropdown(...)
110
+ speaker1 = gr.Dropdown(...)
111
+ speaker2 = gr.Dropdown(...)
112
+ api_key = gr.Textbox(...)
113
+
114
+ generate_btn = gr.Button("Generate Podcast 🎙️", variant="primary")
115
+ output_audio = gr.Audio(...)
116
+
117
+ # Bind async function directly
118
+ generate_btn.click(
119
+ fn=process_input,
120
+ inputs=[input_text, input_file, language, speaker1, speaker2, api_key],
121
+ outputs=output_audio,
122
+ show_progress=True
123
+ )
124
+
125
+ demo.queue()
126
+ demo.launch(server_name="0.0.0.0", share=True, debug=True)