saq1b commited on
Commit
d8bebcf
·
verified ·
1 Parent(s): 0d730da

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +166 -76
app.py CHANGED
@@ -4,20 +4,19 @@ from google import genai
4
  from google.genai import types
5
  import json
6
  import uuid
7
- import io
8
  import edge_tts
9
  import asyncio
10
  import aiofiles
11
- import pypdf
12
  import os
13
  import time
 
14
  from typing import List, Dict
15
 
16
  class PodcastGenerator:
17
  def __init__(self):
18
  pass
19
 
20
- async def generate_script(self, prompt: str, language: str, api_key: str) -> Dict:
21
  example = """
22
  {
23
  "topic": "AGI",
@@ -229,47 +228,81 @@ Follow this example structure:
229
  """
230
  user_prompt = f"Please generate a podcast script based on the following user input:\n{prompt}"
231
 
232
- messages = [
233
- {"role": "user", "parts": [user_prompt]}
234
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  client = genai.Client(api_key=api_key)
237
 
238
  safety_settings = [
239
  {
240
- "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
241
- "threshold": "BLOCK_NONE"
242
  },
243
  {
244
- "category": "HARM_CATEGORY_HARASSMENT",
245
- "threshold": "BLOCK_NONE"
246
  },
247
  {
248
- "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
249
- "threshold": "BLOCK_NONE"
250
  },
251
  {
252
- "category": "HARM_CATEGORY_HATE_SPEECH",
253
- "threshold": "BLOCK_NONE"
254
  }
255
  ]
256
 
257
  try:
258
- response = await client.aio.models.generate_content(
259
- model="gemini-2.0-flash",
260
- contents=messages,
261
- config=types.GenerateContentConfig(
262
- temperature=1,
263
- response_mime_type="application/json",
264
- safety_settings=[
265
- types.SafetySetting(
266
- category=safety_setting["category"],
267
- threshold=safety_setting["threshold"]
268
- ) for safety_setting in safety_settings
269
- ],
270
- system_instruction=system_prompt
271
- )
 
 
 
 
 
 
 
272
  )
 
 
273
  except Exception as e:
274
  if "API key not valid" in str(e):
275
  raise gr.Error("Invalid API key. Please provide a valid Gemini API key.")
@@ -280,7 +313,27 @@ Follow this example structure:
280
 
281
  print(f"Generated podcast script:\n{response.text}")
282
 
 
 
 
283
  return json.loads(response.text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
 
285
  async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
286
  voice = speaker1 if speaker == 1 else speaker2
@@ -288,14 +341,22 @@ Follow this example structure:
288
 
289
  temp_filename = f"temp_{uuid.uuid4()}.wav"
290
  try:
291
- await speech.save(temp_filename)
 
292
  return temp_filename
 
 
 
 
293
  except Exception as e:
294
  if os.path.exists(temp_filename):
295
  os.remove(temp_filename)
296
  raise e
297
 
298
- async def combine_audio_files(self, audio_files: List[str]) -> str:
 
 
 
299
  combined_audio = AudioSegment.empty()
300
  for audio_file in audio_files:
301
  combined_audio += AudioSegment.from_file(audio_file)
@@ -303,44 +364,59 @@ Follow this example structure:
303
 
304
  output_filename = f"output_{uuid.uuid4()}.wav"
305
  combined_audio.export(output_filename, format="wav")
 
 
 
 
306
  return output_filename
307
 
308
- async def generate_podcast(self, input_text: str, language: str, speaker1: str, speaker2: str, api_key: str) -> str:
309
- start_time = time.time()
310
- podcast_json = await self.generate_script(input_text, language, api_key)
311
- end_time = time.time()
312
-
313
- start_time = time.time()
314
- audio_files = await asyncio.gather(*[self.tts_generate(item['line'], item['speaker'], speaker1, speaker2) for item in podcast_json['podcast']])
315
- end_time = time.time()
316
-
317
- combined_audio = await self.combine_audio_files(audio_files)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  return combined_audio
319
 
320
- class TextExtractor:
321
- @staticmethod
322
- async def extract_from_pdf(file_path: str) -> str:
323
- async with aiofiles.open(file_path, 'rb') as file:
324
- content = await file.read()
325
- pdf_reader = pypdf.PdfReader(io.BytesIO(content))
326
- return "\n\n".join(page.extract_text() for page in pdf_reader.pages if page.extract_text())
327
-
328
- @staticmethod
329
- async def extract_from_txt(file_path: str) -> str:
330
- async with aiofiles.open(file_path, 'r') as file:
331
- return await file.read()
332
-
333
- @classmethod
334
- async def extract_text(cls, file_path: str) -> str:
335
- _, file_extension = os.path.splitext(file_path)
336
- if file_extension.lower() == '.pdf':
337
- return await cls.extract_from_pdf(file_path)
338
- elif file_extension.lower() == '.txt':
339
- return await cls.extract_from_txt(file_path)
340
- else:
341
- raise gr.Error(f"Unsupported file type: {file_extension}")
342
-
343
- async def process_input(input_text: str, input_file, language: str, speaker1: str, speaker2: str, api_key: str = "") -> str:
344
  start_time = time.time()
345
 
346
  voice_names = {
@@ -357,20 +433,32 @@ async def process_input(input_text: str, input_file, language: str, speaker1: st
357
  speaker1 = voice_names[speaker1]
358
  speaker2 = voice_names[speaker2]
359
 
360
- if input_file:
361
- input_text = await TextExtractor.extract_text(input_file.name)
362
 
363
- if not api_key:
364
- api_key = os.getenv("GENAI_API_KEY")
 
 
365
 
366
- podcast_generator = PodcastGenerator()
367
- podcast = await podcast_generator.generate_podcast(input_text, language, speaker1, speaker2, api_key)
368
 
369
- end_time = time.time()
370
-
371
- return podcast
 
 
 
 
 
 
 
 
 
 
372
 
373
- # Define Gradio interface
374
  iface = gr.Interface(
375
  fn=process_input,
376
  inputs=[
@@ -422,8 +510,10 @@ iface = gr.Interface(
422
  ],
423
  title="PodcastGen 🎙️",
424
  description="Generate a 2-speaker podcast from text input or documents!",
425
- allow_flagging="never"
 
 
426
  )
427
 
428
  if __name__ == "__main__":
429
- iface.launch()
 
4
  from google.genai import types
5
  import json
6
  import uuid
 
7
  import edge_tts
8
  import asyncio
9
  import aiofiles
 
10
  import os
11
  import time
12
+ import mimetypes
13
  from typing import List, Dict
14
 
15
  class PodcastGenerator:
16
  def __init__(self):
17
  pass
18
 
19
+ async def generate_script(self, prompt: str, language: str, api_key: str, file_obj=None, progress=None) -> Dict:
20
  example = """
21
  {
22
  "topic": "AGI",
 
228
  """
229
  user_prompt = f"Please generate a podcast script based on the following user input:\n{prompt}"
230
 
231
+ messages = []
232
+
233
+ # If file is provided, add it to the messages
234
+ if file_obj:
235
+ file_data = await self._read_file_bytes(file_obj)
236
+ mime_type = self._get_mime_type(file_obj.name)
237
+
238
+ messages.append(
239
+ types.Content(
240
+ role="user",
241
+ parts=[
242
+ types.Part.from_bytes(
243
+ data=file_data,
244
+ mime_type=mime_type,
245
+ )
246
+ ],
247
+ )
248
+ )
249
+
250
+ # Add text prompt
251
+ messages.append(
252
+ types.Content(
253
+ role="user",
254
+ parts=[
255
+ types.Part.from_text(text=user_prompt)
256
+ ],
257
+ )
258
+ )
259
 
260
  client = genai.Client(api_key=api_key)
261
 
262
  safety_settings = [
263
  {
264
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
265
+ "threshold": "BLOCK_NONE"
266
  },
267
  {
268
+ "category": "HARM_CATEGORY_HARASSMENT",
269
+ "threshold": "BLOCK_NONE"
270
  },
271
  {
272
+ "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
273
+ "threshold": "BLOCK_NONE"
274
  },
275
  {
276
+ "category": "HARM_CATEGORY_HATE_SPEECH",
277
+ "threshold": "BLOCK_NONE"
278
  }
279
  ]
280
 
281
  try:
282
+ if progress:
283
+ progress(0.3, "Generating podcast script...")
284
+
285
+ # Add timeout to the API call
286
+ response = await asyncio.wait_for(
287
+ client.aio.models.generate_content(
288
+ model="gemini-2.0-flash",
289
+ contents=messages,
290
+ config=types.GenerateContentConfig(
291
+ temperature=1,
292
+ response_mime_type="application/json",
293
+ safety_settings=[
294
+ types.SafetySetting(
295
+ category=safety_setting["category"],
296
+ threshold=safety_setting["threshold"]
297
+ ) for safety_setting in safety_settings
298
+ ],
299
+ system_instruction=system_prompt
300
+ )
301
+ ),
302
+ timeout=60 # 60 seconds timeout
303
  )
304
+ except asyncio.TimeoutError:
305
+ raise gr.Error("The script generation request timed out. Please try again later.")
306
  except Exception as e:
307
  if "API key not valid" in str(e):
308
  raise gr.Error("Invalid API key. Please provide a valid Gemini API key.")
 
313
 
314
  print(f"Generated podcast script:\n{response.text}")
315
 
316
+ if progress:
317
+ progress(0.4, "Script generated successfully!")
318
+
319
  return json.loads(response.text)
320
+
321
+ async def _read_file_bytes(self, file_obj) -> bytes:
322
+ """Read file bytes from a file object"""
323
+ async with aiofiles.open(file_obj.name, 'rb') as f:
324
+ return await f.read()
325
+
326
+ def _get_mime_type(self, filename: str) -> str:
327
+ """Determine MIME type based on file extension"""
328
+ ext = os.path.splitext(filename)[1].lower()
329
+ if ext == '.pdf':
330
+ return "application/pdf"
331
+ elif ext == '.txt':
332
+ return "text/plain"
333
+ else:
334
+ # Fallback to the default mime type detector
335
+ mime_type, _ = mimetypes.guess_type(filename)
336
+ return mime_type or "application/octet-stream"
337
 
338
  async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
339
  voice = speaker1 if speaker == 1 else speaker2
 
341
 
342
  temp_filename = f"temp_{uuid.uuid4()}.wav"
343
  try:
344
+ # Add timeout to TTS generation
345
+ await asyncio.wait_for(speech.save(temp_filename), timeout=30) # 30 seconds timeout
346
  return temp_filename
347
+ except asyncio.TimeoutError:
348
+ if os.path.exists(temp_filename):
349
+ os.remove(temp_filename)
350
+ raise gr.Error("Text-to-speech generation timed out. Please try with a shorter text.")
351
  except Exception as e:
352
  if os.path.exists(temp_filename):
353
  os.remove(temp_filename)
354
  raise e
355
 
356
+ async def combine_audio_files(self, audio_files: List[str], progress=None) -> str:
357
+ if progress:
358
+ progress(0.9, "Combining audio files...")
359
+
360
  combined_audio = AudioSegment.empty()
361
  for audio_file in audio_files:
362
  combined_audio += AudioSegment.from_file(audio_file)
 
364
 
365
  output_filename = f"output_{uuid.uuid4()}.wav"
366
  combined_audio.export(output_filename, format="wav")
367
+
368
+ if progress:
369
+ progress(1.0, "Podcast generated successfully!")
370
+
371
  return output_filename
372
 
373
+ async def generate_podcast(self, input_text: str, language: str, speaker1: str, speaker2: str, api_key: str, file_obj=None, progress=None) -> str:
374
+ try:
375
+ if progress:
376
+ progress(0.1, "Starting podcast generation...")
377
+
378
+ # Set overall timeout for the entire process
379
+ return await asyncio.wait_for(
380
+ self._generate_podcast_internal(input_text, language, speaker1, speaker2, api_key, file_obj, progress),
381
+ timeout=600 # 10 minutes total timeout
382
+ )
383
+ except asyncio.TimeoutError:
384
+ raise gr.Error("The podcast generation process timed out. Please try with shorter text or try again later.")
385
+ except Exception as e:
386
+ raise gr.Error(f"Error generating podcast: {str(e)}")
387
+
388
+ async def _generate_podcast_internal(self, input_text: str, language: str, speaker1: str, speaker2: str, api_key: str, file_obj=None, progress=None) -> str:
389
+ if progress:
390
+ progress(0.2, "Generating podcast script...")
391
+
392
+ podcast_json = await self.generate_script(input_text, language, api_key, file_obj, progress)
393
+
394
+ if progress:
395
+ progress(0.5, "Converting text to speech...")
396
+
397
+ # Process TTS in batches to prevent overwhelming the system
398
+ audio_files = []
399
+ total_lines = len(podcast_json['podcast'])
400
+
401
+ for i, item in enumerate(podcast_json['podcast']):
402
+ if progress:
403
+ current_progress = 0.5 + (0.4 * (i / total_lines))
404
+ progress(current_progress, f"Processing speech {i+1}/{total_lines}...")
405
+
406
+ try:
407
+ audio_file = await self.tts_generate(item['line'], item['speaker'], speaker1, speaker2)
408
+ audio_files.append(audio_file)
409
+ except Exception as e:
410
+ # Clean up any files already created
411
+ for file in audio_files:
412
+ if os.path.exists(file):
413
+ os.remove(file)
414
+ raise gr.Error(f"Error generating speech for line {i+1}: {str(e)}")
415
+
416
+ combined_audio = await self.combine_audio_files(audio_files, progress)
417
  return combined_audio
418
 
419
+ async def process_input(input_text: str, input_file, language: str, speaker1: str, speaker2: str, api_key: str = "", progress=gr.Progress()) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  start_time = time.time()
421
 
422
  voice_names = {
 
433
  speaker1 = voice_names[speaker1]
434
  speaker2 = voice_names[speaker2]
435
 
436
+ try:
437
+ progress(0.05, "Processing input...")
438
 
439
+ if not api_key:
440
+ api_key = os.getenv("GENAI_API_KEY")
441
+ if not api_key:
442
+ raise gr.Error("No API key provided. Please provide a Gemini API key.")
443
 
444
+ podcast_generator = PodcastGenerator()
445
+ podcast = await podcast_generator.generate_podcast(input_text, language, speaker1, speaker2, api_key, input_file, progress)
446
 
447
+ end_time = time.time()
448
+ print(f"Total podcast generation time: {end_time - start_time:.2f} seconds")
449
+ return podcast
450
+
451
+ except Exception as e:
452
+ # Ensure we show a user-friendly error
453
+ error_msg = str(e)
454
+ if "rate limit" in error_msg.lower():
455
+ raise gr.Error("Rate limit exceeded. Please try again later or use your own API key.")
456
+ elif "timeout" in error_msg.lower():
457
+ raise gr.Error("The request timed out. This could be due to server load or the length of your input. Please try again with shorter text.")
458
+ else:
459
+ raise gr.Error(f"Error: {error_msg}")
460
 
461
+ # Define Gradio interface with concurrency control
462
  iface = gr.Interface(
463
  fn=process_input,
464
  inputs=[
 
510
  ],
511
  title="PodcastGen 🎙️",
512
  description="Generate a 2-speaker podcast from text input or documents!",
513
+ allow_flagging="never",
514
+ concurrency_limit=3, # Limit concurrent requests to prevent overload
515
+ concurrency_id="podcast_gen" # Identifier for concurrency group
516
  )
517
 
518
  if __name__ == "__main__":
519
+ iface.queue(max_size=10).launch() # Set maximum queue size