PuristanLabs1 committed on
Commit
8b12154
·
verified ·
1 Parent(s): 392490b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -12
app.py CHANGED
@@ -26,11 +26,7 @@ import tempfile
26
  nltk.download("punkt")
27
  nltk.download("punkt_tab")
28
 
29
- # Automatically select device based on hardware availability
30
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
- print(f"✅ Using {DEVICE.upper()} for TTS and Summarization")
32
-
33
- kokoro_tts = KPipeline(lang_code='a', device=DEVICE)
34
 
35
  # Supported TTS Languages
36
  SUPPORTED_TTS_LANGUAGES = {
@@ -53,7 +49,7 @@ model_name = "facebook/bart-large-cnn"
53
 
54
  try:
55
  tokenizer = BartTokenizer.from_pretrained(model_name, cache_dir=os.path.join(os.getcwd(), ".cache"))
56
- model = BartForConditionalGeneration.from_pretrained(model_name, cache_dir=os.path.join(os.getcwd(), ".cache")).to(DEVICE)
57
 
58
  except Exception as e:
59
  raise RuntimeError(f"Error loading BART model: {e}")
@@ -106,7 +102,7 @@ def fetch_and_display_content(url):
106
  # Add detected language to metadata
107
  metadata["Detected Language"] = detected_lang.upper()
108
 
109
- return (
110
  cleaned_text,
111
  metadata,
112
  detected_lang,
@@ -283,10 +279,11 @@ def summarize_text(text, max_input_tokens=1024, max_output_tokens=200):
283
  inputs = tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=max_input_tokens, truncation=True)
284
  summary_ids = model.generate(inputs, max_length=max_output_tokens, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)
285
  return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
286
-
 
287
  def hierarchical_summarization(text):
288
  """Performs hierarchical summarization by chunking content first."""
289
- print(f"✅ Summarization will run on: {DEVICE.upper()}")
290
 
291
  if len(text) > 10000:
292
  print("⚠️ Warning: Large input text detected. Summarization may take longer than usual.")
@@ -299,7 +296,7 @@ def hierarchical_summarization(text):
299
  padding=True,
300
  truncation=True,
301
  max_length=1024
302
- ).to(DEVICE)
303
 
304
  #Generate the summary
305
  summary_ids = model.generate(
@@ -403,7 +400,9 @@ with gr.Blocks() as demo:
403
  summary_output = gr.Textbox(label="Summary", visible=True, interactive=False)
404
  full_audio_output = gr.Audio(label="Generated Audio", visible=True)
405
  ner_output = gr.Textbox(label="Extracted Entities", visible=True, interactive=False)
406
-
 
 
407
  default_entity_types = gr.Textbox(label="Default Entity Types", value="PERSON, Organization, location, Date, PRODUCT, EVENT", interactive=True)
408
  custom_entity_types = gr.Textbox(label="Custom Entity Types", placeholder="Enter additional entity types (comma-separated)", interactive=True)
409
 
@@ -454,7 +453,7 @@ with gr.Blocks() as demo:
454
 
455
  extract_entities_with_gliner,
456
 
457
- inputs=[extracted_text, default_entity_types, custom_entity_types],
458
  outputs=[ner_output]
459
  )
460
 
 
26
  nltk.download("punkt")
27
  nltk.download("punkt_tab")
28
 
29
+ kokoro_tts = KPipeline(lang_code='a')
 
 
 
 
30
 
31
  # Supported TTS Languages
32
  SUPPORTED_TTS_LANGUAGES = {
 
49
 
50
  try:
51
  tokenizer = BartTokenizer.from_pretrained(model_name, cache_dir=os.path.join(os.getcwd(), ".cache"))
52
+ model = BartForConditionalGeneration.from_pretrained(model_name, cache_dir=os.path.join(os.getcwd(), ".cache"))
53
 
54
  except Exception as e:
55
  raise RuntimeError(f"Error loading BART model: {e}")
 
102
  # Add detected language to metadata
103
  metadata["Detected Language"] = detected_lang.upper()
104
 
105
+ return (
106
  cleaned_text,
107
  metadata,
108
  detected_lang,
 
279
  inputs = tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=max_input_tokens, truncation=True)
280
  summary_ids = model.generate(inputs, max_length=max_output_tokens, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)
281
  return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
282
+
283
+ @spaces.GPU(duration=1000)
284
  def hierarchical_summarization(text):
285
  """Performs hierarchical summarization by chunking content first."""
286
+ #print(f"✅ Summarization will run on: {DEVICE.upper()}")
287
 
288
  if len(text) > 10000:
289
  print("⚠️ Warning: Large input text detected. Summarization may take longer than usual.")
 
296
  padding=True,
297
  truncation=True,
298
  max_length=1024
299
+ )
300
 
301
  #Generate the summary
302
  summary_ids = model.generate(
 
400
  summary_output = gr.Textbox(label="Summary", visible=True, interactive=False)
401
  full_audio_output = gr.Audio(label="Generated Audio", visible=True)
402
  ner_output = gr.Textbox(label="Extracted Entities", visible=True, interactive=False)
403
+
404
+
405
+
406
  default_entity_types = gr.Textbox(label="Default Entity Types", value="PERSON, Organization, location, Date, PRODUCT, EVENT", interactive=True)
407
  custom_entity_types = gr.Textbox(label="Custom Entity Types", placeholder="Enter additional entity types (comma-separated)", interactive=True)
408
 
 
453
 
454
  extract_entities_with_gliner,
455
 
456
+ inputs=[extracted_text, default_entity_types, custom_entity_types],
457
  outputs=[ner_output]
458
  )
459