Update app.py
app.py CHANGED
@@ -26,11 +26,7 @@ import tempfile
 nltk.download("punkt")
 nltk.download("punkt_tab")
 
-
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-print(f"✅ Using {DEVICE.upper()} for TTS and Summarization")
-
-kokoro_tts = KPipeline(lang_code='a', device=DEVICE)
+kokoro_tts = KPipeline(lang_code='a')
 
 # Supported TTS Languages
 SUPPORTED_TTS_LANGUAGES = {
@@ -53,7 +49,7 @@ model_name = "facebook/bart-large-cnn"
 
 try:
     tokenizer = BartTokenizer.from_pretrained(model_name, cache_dir=os.path.join(os.getcwd(), ".cache"))
-    model = BartForConditionalGeneration.from_pretrained(model_name, cache_dir=os.path.join(os.getcwd(), ".cache"))
+    model = BartForConditionalGeneration.from_pretrained(model_name, cache_dir=os.path.join(os.getcwd(), ".cache"))
 
 except Exception as e:
     raise RuntimeError(f"Error loading BART model: {e}")
@@ -106,7 +102,7 @@ def fetch_and_display_content(url):
     # Add detected language to metadata
     metadata["Detected Language"] = detected_lang.upper()
 
-
+    return (
         cleaned_text,
         metadata,
         detected_lang,
@@ -283,10 +279,11 @@ def summarize_text(text, max_input_tokens=1024, max_output_tokens=200):
     inputs = tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=max_input_tokens, truncation=True)
     summary_ids = model.generate(inputs, max_length=max_output_tokens, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)
     return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-
+
+@spaces.GPU(duration=1000)
 def hierarchical_summarization(text):
     """Performs hierarchical summarization by chunking content first."""
-    print(f"✅ Summarization will run on: {DEVICE.upper()}")
+    #print(f"✅ Summarization will run on: {DEVICE.upper()}")
 
     if len(text) > 10000:
         print("⚠️ Warning: Large input text detected. Summarization may take longer than usual.")
@@ -299,7 +296,7 @@ def hierarchical_summarization(text):
         padding=True,
         truncation=True,
         max_length=1024
-    )
+    )
 
     #Generate the summary
     summary_ids = model.generate(
@@ -403,7 +400,9 @@ with gr.Blocks() as demo:
     summary_output = gr.Textbox(label="Summary", visible=True, interactive=False)
     full_audio_output = gr.Audio(label="Generated Audio", visible=True)
    ner_output = gr.Textbox(label="Extracted Entities", visible=True, interactive=False)
-
+
+
+
     default_entity_types = gr.Textbox(label="Default Entity Types", value="PERSON, Organization, location, Date, PRODUCT, EVENT", interactive=True)
     custom_entity_types = gr.Textbox(label="Custom Entity Types", placeholder="Enter additional entity types (comma-separated)", interactive=True)
 
@@ -454,7 +453,7 @@
 
        extract_entities_with_gliner,
 
-
+        inputs=[extracted_text, default_entity_types, custom_entity_types],
         outputs=[ner_output]
     )
 
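A note on the main change in this commit: the module-level DEVICE probe is removed, the Kokoro KPipeline is constructed without an explicit device, and the summarization entry point is instead decorated with @spaces.GPU(duration=1000). On Hugging Face ZeroGPU Spaces, a GPU is allocated on demand while a @spaces.GPU-decorated function runs rather than being reserved for the whole process, which is why the process-wide DEVICE constant is dropped. A minimal sketch of the pattern, assuming a ZeroGPU Space with the spaces package installed; heavy_summarize is a hypothetical stand-in for the app's function:

import spaces  # Hugging Face ZeroGPU helper package
from transformers import BartForConditionalGeneration, BartTokenizer

# Load once at import time; no explicit device placement needed here.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

@spaces.GPU(duration=1000)  # request a GPU for up to 1000 seconds per call
def heavy_summarize(text: str) -> str:
    # Hypothetical stand-in for hierarchical_summarization: GPU hardware
    # is attached only while this decorated function is running.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
    summary_ids = model.generate(**inputs, max_length=200, min_length=50,
                                 num_beams=4, early_stopping=True)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

The duration value mirrors the diff and caps how long each call may hold the GPU.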
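The final hunk fills in the inputs= list of the entity-extraction listener so the current values of the three textboxes are passed to extract_entities_with_gliner. A sketch of the resulting wiring; the button name ner_button and the callback body are assumptions, since the diff shows neither:

import gradio as gr

def extract_entities_with_gliner(text, default_types, custom_types):
    # Placeholder body; the real app runs GLiNER over `text`
    # with the requested entity types.
    return f"{len(text)} characters scanned for: {default_types}, {custom_types}"

with gr.Blocks() as demo:
    extracted_text = gr.Textbox(label="Extracted Text")
    default_entity_types = gr.Textbox(
        label="Default Entity Types",
        value="PERSON, Organization, location, Date, PRODUCT, EVENT")
    custom_entity_types = gr.Textbox(
        label="Custom Entity Types",
        placeholder="Enter additional entity types (comma-separated)")
    ner_output = gr.Textbox(label="Extracted Entities", interactive=False)
    ner_button = gr.Button("Extract Entities")  # assumed name; not shown in the diff

    # The components in `inputs` are read when the button is clicked and
    # passed to the function positionally; `outputs` receive its return value.
    ner_button.click(
        extract_entities_with_gliner,
        inputs=[extracted_text, default_entity_types, custom_entity_types],
        outputs=[ner_output],
    )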