PuristanLabs1 committed
Commit 392490b · verified · 1 Parent(s): b097fde

Update app.py

Files changed (1):
  1. app.py (+154 −45)

app.py CHANGED
@@ -19,17 +19,18 @@ from wordcloud import WordCloud
 import matplotlib.pyplot as plt
 from PIL import Image
 import io
+import requests
 from gliner import GLiNER
 import tempfile
 
 nltk.download("punkt")
 nltk.download("punkt_tab")
 
-stanza.download("en")
-nlp = stanza.Pipeline("en", processors="tokenize,ner", use_gpu=False)
+# Automatically select device based on hardware availability
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"✅ Using {DEVICE.upper()} for TTS and Summarization")
 
-
-kokoro_tts = KPipeline(lang_code='a', device="cpu")
+kokoro_tts = KPipeline(lang_code='a', device=DEVICE)
 
 # Supported TTS Languages
 SUPPORTED_TTS_LANGUAGES = {
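Review note: the new `DEVICE` selection assumes `torch` is already imported further up in app.py (the import is not part of this diff). If it is not, a minimal guarded sketch like this keeps the CPU path working:

```python
# Minimal sketch, assuming only that PyTorch may be absent:
# fall back to CPU when torch cannot be imported.
try:
    import torch
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
except ImportError:
    DEVICE = "cpu"
print(f"✅ Using {DEVICE.upper()} for TTS and Summarization")
```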
@@ -48,28 +49,73 @@ AVAILABLE_VOICES = [
 
 # Load BART Large CNN Model for Summarization
 model_name = "facebook/bart-large-cnn"
-tokenizer = BartTokenizer.from_pretrained(model_name)
-model = BartForConditionalGeneration.from_pretrained(model_name)
+
+
+try:
+    tokenizer = BartTokenizer.from_pretrained(model_name, cache_dir=os.path.join(os.getcwd(), ".cache"))
+    model = BartForConditionalGeneration.from_pretrained(model_name, cache_dir=os.path.join(os.getcwd(), ".cache")).to(DEVICE)
+
+except Exception as e:
+    raise RuntimeError(f"Error loading BART model: {e}")
 
 # Initialize GLINER model
 gliner_model = GLiNER.from_pretrained("urchade/gliner_base")
 
-
+def is_pdf_url(url):
+    """Robustly detects PDF files via URL patterns and Content-Type headers."""
+    # URL Pattern Check
+    if url.endswith(".pdf") or "pdf" in url.lower():
+        return True
+
+    # Check Content-Type Header (for URLs without '.pdf')
+    try:
+        response = requests.head(url, timeout=10)
+        content_type = response.headers.get('Content-Type', '')
+        if 'application/pdf' in content_type:
+            return True
+    except requests.RequestException:
+        pass  # Ignore errors in Content-Type check
+
+    return False
+
 def fetch_and_display_content(url):
-    """Fetch and extract text from a given URL (HTML or PDF)."""
-    if url.endswith(".pdf") or "pdf" in url:
-        converter = MarkItDown()
-
-        text = converter.convert(url).text_content
+    """
+    Fetch and extract text from a given URL (HTML or PDF).
+    Extract metadata, clean text, and detect language.
+    """
+
+    downloaded = trafilatura.fetch_url(url)
+    if not downloaded:
+        raise ValueError(f"❌ Failed to fetch content from URL: {url}")
+
+    if is_pdf_url(url):
+        converter = MarkItDown(enable_plugins=False)
+        try:
+            text = converter.convert(url).text_content
+        except Exception as e:
+            raise RuntimeError(f"❌ Error converting PDF with MarkItDown: {e}")
     else:
-        downloaded = trafilatura.fetch_url(url)
-        text = extract(downloaded, output_format="markdown", with_metadata=True, include_tables=False, include_links=False, include_formatting=True, include_comments=False)  # without metadata extraction
+        text = extract(downloaded, output_format="markdown", with_metadata=True, include_tables=False, include_links=False, include_formatting=True, include_comments=False)
+
+    if not text or len(text.strip()) == 0:
+        raise ValueError("❌ No content found in the extracted data.")
+
     metadata, cleaned_text = extract_and_clean_text(text)
     detected_lang = detect_language(cleaned_text)
 
     # Add detected language to metadata
    metadata["Detected Language"] = detected_lang.upper()
-    return cleaned_text, metadata, detected_lang, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
+
+    return (
+        cleaned_text,
+        metadata,
+        detected_lang,
+        gr.update(visible=True),  # Show Word Cloud
+        gr.update(visible=True),  # Show Process Audio Button
+        gr.update(visible=True),  # Show Process NER Button
+        gr.update(visible=True),  # Show Extracted Text
+        gr.update(visible=True)   # Show Metadata Output
+    )
 
 
 def extract_and_clean_text(data):
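Review note: probing with `requests.head` is a sensible first check, but some servers reject HEAD requests or omit the Content-Type header, in which case `is_pdf_url` silently falls through to `False`. Note also that `"pdf" in url.lower()` matches any URL merely containing the substring (e.g. an HTML page under /pdf-guide/), so the header check never runs for those. A hedged fallback sketch (the helper name `looks_like_pdf` is hypothetical, not part of this commit) that sniffs the `%PDF` magic bytes with a streamed GET:

```python
import requests

def looks_like_pdf(url, timeout=10):
    """Hypothetical fallback: fetch only the first bytes and check the %PDF magic number."""
    try:
        with requests.get(url, stream=True, timeout=timeout) as resp:
            first_bytes = next(resp.iter_content(chunk_size=5), b"")
            return first_bytes.startswith(b"%PDF")
    except requests.RequestException:
        return False
```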
@@ -171,13 +217,11 @@ def extract_entities_with_stanza(text, chunk_size=1000):
 
 def generate_wordcloud(text):
 
-    if not text:
-        return None
-
-
+    if not text.strip():
+        raise ValueError("❌ Text is empty or invalid for WordCloud generation.")
+
     wordcloud = WordCloud(width=800, height=400, background_color='white').generate(text)
-
-
+
     plt.figure(figsize=(10, 5))
     plt.imshow(wordcloud, interpolation='bilinear')
     plt.axis('off')
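Review note: raising a bare `ValueError` inside a Gradio callback surfaces as a generic error in the UI. If a friendlier message is wanted, Gradio's own exception type can be raised instead; a minimal sketch, assuming `gradio` is imported as `gr`:

```python
import gradio as gr
from wordcloud import WordCloud

def generate_wordcloud(text):
    # Raise a user-facing Gradio error instead of a bare ValueError.
    if not text or not text.strip():
        raise gr.Error("Text is empty or invalid for WordCloud generation.")
    return WordCloud(width=800, height=400, background_color='white').generate(text)
```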
@@ -196,11 +240,8 @@ def generate_wordcloud(text):
 def generate_audio_kokoro(text, lang, selected_voice):
     """Generate speech using KokoroTTS for supported languages."""
     global kokoro_tts
-    if os.path.exists(f"audio_{lang}.wav"):
-        os.remove(f"audio_{lang}.wav")
 
     lang_code = SUPPORTED_TTS_LANGUAGES.get(lang, "a")  # Default to English
-    #generator = kokoro_tts(text, voice="bm_george", speed=1, split_pattern=r'\n+')
     generator = kokoro_tts(text, voice=selected_voice, speed=1, split_pattern=r'\n+')
 
 
@@ -242,26 +283,90 @@ def summarize_text(text, max_input_tokens=1024, max_output_tokens=200):
     inputs = tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=max_input_tokens, truncation=True)
     summary_ids = model.generate(inputs, max_length=max_output_tokens, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)
     return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
 def hierarchical_summarization(text):
-
+    """Performs hierarchical summarization by chunking content first."""
+    print(f"✅ Summarization will run on: {DEVICE.upper()}")
+
+    if len(text) > 10000:
+        print("⚠️ Warning: Large input text detected. Summarization may take longer than usual.")
+
     chunks = split_text_with_optimized_overlap(text)
+    # Tokenize the input cleaned text
+    encoded_inputs = tokenizer(
+        ["summarize: " + chunk for chunk in chunks],
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=1024
+    ).to(DEVICE)
+
+    # Generate the summary
+    summary_ids = model.generate(
+        encoded_inputs["input_ids"],
+        max_length=200,
+        min_length=50,
+        length_penalty=2.0,
+        num_beams=4,
+        early_stopping=True
+    )
 
-    chunk_summaries = [summarize_text(chunk) for chunk in chunks]
+    # Decode the summary generated in the step above
+    chunk_summaries = [tokenizer.decode(ids, skip_special_tokens=True) for ids in summary_ids]
     final_summary = " ".join(chunk_summaries)
-    return final_summary
+    return final_summary
 
-def extract_entities_with_gliner(text, default_entity_types, custom_entity_types):
+def chunk_text_with_overlap(text, max_tokens=500, overlap_tokens=50):
+    """Splits text into overlapping chunks for large document processing."""
+    sentences = re.split(r'(?<=[.!?])\s+', text)  # Split on sentence boundaries
+    chunks = []
+    current_chunk = []
+    current_length = 0
+    previous_chunk_text = ""
+
+    for sentence in sentences:
+        token_length = len(sentence.split())
+        if current_length + token_length > max_tokens:
+            chunks.append(previous_chunk_text + " " + " ".join(current_chunk))
+            previous_chunk_text = " ".join(current_chunk)[-overlap_tokens:]
+            current_chunk = [sentence]
+            current_length = token_length
+        else:
+            current_chunk.append(sentence)
+            current_length += token_length
+
+    if current_chunk:
+        chunks.append(previous_chunk_text + " " + " ".join(current_chunk))
+
+    return chunks
+
+def extract_entities_with_gliner(text, default_entity_types, custom_entity_types, batch_size=4):
     """
-    Extract entities using GLINER with default and custom entity types.
+    Extract entities using GLINER with efficient chunking, sliding window, and batching.
     """
-
-    entity_types = default_entity_types.split(",") + [etype.strip() for etype in custom_entity_types.split(",") if custom_entity_types]
-
+    # Entity types preparation
+    entity_types = default_entity_types.split(",") + [
+        etype.strip() for etype in custom_entity_types.split(",") if custom_entity_types
+    ]
     entity_types = list(set([etype.strip() for etype in entity_types if etype.strip()]))
-
-    entities = gliner_model.predict_entities(text, entity_types)
-
-    formatted_entities = "\n".join([f"{i+1}: {ent['text']} --> {ent['label']}" for i, ent in enumerate(entities)])
+
+    # Chunk the text to avoid overflow
+    chunks = chunk_text_with_overlap(text)
+
+    # Process each chunk individually for improved stability
+    all_entities = []
+    for i, chunk in enumerate(chunks):
+        try:
+            entities = gliner_model.predict_entities(chunk, entity_types)
+            all_entities.extend(entities)
+        except Exception as e:
+            print(f"⚠️ Error processing chunk {i}: {e}")
+
+    # Format the results
+    formatted_entities = "\n".join(
+        [f"{i+1}: {ent['text']} --> {ent['label']}" for i, ent in enumerate(all_entities)]
+    )
+
     return formatted_entities
 
 ### 5️⃣ Main Processing Function
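Review note: two behaviors in this hunk are worth flagging. First, `model.generate` now receives every chunk in a single padded batch, which is fast but can exhaust memory on very long documents. Second, the unchanged `summarize_text` above still builds its input tensor on CPU, so if it is still called anywhere it would need a matching `.to(DEVICE)` now that the model may live on GPU. A small illustrative sketch (synthetic input only) of how the new `chunk_text_with_overlap` behaves:

```python
# Synthetic text: 50 copies of three short sentences.
sample = "First sentence. Second one! A third? " * 50
pieces = chunk_text_with_overlap(sample, max_tokens=30, overlap_tokens=10)
print(len(pieces))      # several overlapping chunks
print(pieces[1][:40])   # each chunk begins with a tail of the previous one
# Caveat: overlap_tokens slices the last N *characters* of the previous
# chunk (" ".join(...)[-overlap_tokens:]), not whole tokens, so the
# overlap is approximate.
```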
@@ -298,9 +403,7 @@ with gr.Blocks() as demo:
     summary_output = gr.Textbox(label="Summary", visible=True, interactive=False)
     full_audio_output = gr.Audio(label="Generated Audio", visible=True)
     ner_output = gr.Textbox(label="Extracted Entities", visible=True, interactive=False)
-
-
-
+
     default_entity_types = gr.Textbox(label="Default Entity Types", value="PERSON, Organization, location, Date, PRODUCT, EVENT", interactive=True)
     custom_entity_types = gr.Textbox(label="Custom Entity Types", placeholder="Enter additional entity types (comma-separated)", interactive=True)
 
@@ -320,14 +423,20 @@ with gr.Blocks() as demo:
         show_progress=True
     )
 
+    # Step 3: Summarization (Generate Summary Before Enabling TTS Button)
+    def generate_summary_and_enable_tts(text):
+        summary = hierarchical_summarization(text)
+        return summary, gr.update(visible=True)  # Enable the TTS button only after the summary is generated
+
+    # Summarization
     extracted_text.change(
-        hierarchical_summarization,
+        generate_summary_and_enable_tts,
         inputs=[extracted_text],
-        outputs=[summary_output],
+        outputs=[summary_output, process_audio_button],
         show_progress=True
     )
 
-
+    # Audio Generation
     process_audio_button.click(
         lambda text, summary, lang, voice, tts_choice: (
             None,  # Clear previous audio
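Review note: the new callback reveals the button with `gr.update(visible=True)`; if the intent is to keep the button visible but disabled until the summary exists, `gr.update(interactive=...)` is the more usual Gradio pattern. A minimal sketch, assuming the same component names:

```python
def generate_summary_and_enable_tts(text):
    summary = hierarchical_summarization(text)
    # Enable (rather than reveal) the TTS button once the summary exists.
    return summary, gr.update(interactive=True)
```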
@@ -340,13 +449,13 @@ with gr.Blocks() as demo:
         show_progress=True
     )
 
-
+    # NER Extraction
     process_ner_button.click(
 
         extract_entities_with_gliner,
 
-        inputs=[extracted_text, default_entity_types, custom_entity_types],
+        inputs=[extracted_text, default_entity_types, custom_entity_types],
         outputs=[ner_output]
     )
 
-demo.launch()
+demo.launch(share=True)
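Review note: `share=True` also exposes a temporary public Gradio link for the app. For a quick smoke test of the new helpers from a REPL, without going through the UI, something along these lines works (the URL is a hypothetical placeholder):

```python
url = "https://example.com/sample-article"  # hypothetical test URL
text, metadata, lang, *_ = fetch_and_display_content(url)
print(metadata)
print(hierarchical_summarization(text)[:300])
print(extract_entities_with_gliner(text, "PERSON, Organization, location", ""))
```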
 