Commit a67ba36 (parent 4992a8e): split those sentences

app.py CHANGED
@@ -6,10 +6,11 @@ import logging
 import torch
 import nltk
 import os
-
+import re
 
 from nltk.tokenize import sent_tokenize
 
+# Configure NLTK to use preloaded data path
 nltk_data_path = os.getenv("NLTK_DATA", "/home/user/nltk_data")
 nltk.data.path.append(nltk_data_path)
 
@@ -28,7 +29,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 # Token constraints
 MAX_MODEL_TOKENS = 1024
-SAFE_CHUNK_SIZE =
+SAFE_CHUNK_SIZE = 600  # Reduced to leave room for special tokens
 
 # Pydantic schemas
 class SummarizationItem(BaseModel):
@@ -45,10 +46,40 @@ class SummarizationResponseItem(BaseModel):
 class BatchSummarizationResponse(BaseModel):
     summaries: List[SummarizationResponseItem]
 
-# Sentence
-def split_sentences(text: str) -> list[str]:
-
+# Sentence splitter with fallback for long sentences
+def split_sentences(text: str, max_sentence_tokens: int = SAFE_CHUNK_SIZE) -> list[str]:
+    sentences = sent_tokenize(text.strip())
+    split_results = []
 
+    for sentence in sentences:
+        token_len = len(tokenizer.tokenize(sentence))
+        if token_len <= max_sentence_tokens:
+            split_results.append(sentence)
+        else:
+            # Fallback: split by commas/semicolons
+            sub_sentences = re.split(r'[;,:]\s+', sentence)
+            for sub in sub_sentences:
+                sub = sub.strip()
+                if not sub:
+                    continue
+                if len(tokenizer.tokenize(sub)) <= max_sentence_tokens:
+                    split_results.append(sub)
+                else:
+                    # Final fallback: hard-split by word
+                    words = sub.split()
+                    buffer = []
+                    for word in words:
+                        buffer.append(word)
+                        current = " ".join(buffer)
+                        if len(tokenizer.tokenize(current)) > max_sentence_tokens:
+                            split_results.append(" ".join(buffer[:-1]))
+                            buffer = [word]
+                    if buffer:
+                        split_results.append(" ".join(buffer))
+
+    return split_results
+
+# Chunking based on token length
 def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
     sentences = split_sentences(text)
     chunks = []
@@ -56,7 +87,7 @@ def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
 
     for sentence in sentences:
         tentative_chunk = " ".join(current_chunk_sentences + [sentence])
-        token_count = len(tokenizer.
+        token_count = len(tokenizer.tokenize(tentative_chunk))
 
         if token_count <= max_tokens:
             current_chunk_sentences.append(sentence)
@@ -68,12 +99,11 @@ def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
     if current_chunk_sentences:
         chunks.append(" ".join(current_chunk_sentences))
 
-    # Final
+    # Final model-safe filtering
     final_chunks = []
     for chunk in chunks:
         encoded = tokenizer(chunk, return_tensors="pt", truncation=False, add_special_tokens=False)
         token_len = encoded["input_ids"].shape[1]
-
         if token_len <= MAX_MODEL_TOKENS:
             final_chunks.append(chunk)
         else:
@@ -98,7 +128,7 @@ async def summarize_batch(request: BatchSummarizationRequest):
         logger.error("No valid chunks after filtering. Returning empty response.")
         return {"summaries": []}
 
-    #
+    # Inference
    summaries = summarizer(
         all_chunks,
         max_length=150,
@@ -108,7 +138,7 @@ async def summarize_batch(request: BatchSummarizationRequest):
         batch_size=4
     )
 
-    #
+    # Merge summaries by content_id
     summary_map = {}
     for content_id, result in zip(chunk_map, summaries):
         summary_map.setdefault(content_id, []).append(result["summary_text"])
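A quick way to exercise the new fallback path locally (a sketch, not part of the commit: it assumes app.py is importable as a module, which also loads the tokenizer and model at import time): feed split_sentences a long run of words with no sentence or clause punctuation and confirm every returned piece respects the token budget.

# Hypothetical smoke test for the word-level hard-split fallback.
from app import split_sentences, tokenizer, SAFE_CHUNK_SIZE

monster = "token " * 2000  # one giant "sentence": no '.', ';', ',' or ':' to split on
pieces = split_sentences(monster)

# Every piece should stay within the per-sentence budget the fallback enforces.
assert all(len(tokenizer.tokenize(p)) <= SAFE_CHUNK_SIZE for p in pieces)
print(f"{len(pieces)} pieces, all within {SAFE_CHUNK_SIZE} tokens")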
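The same kind of check, end to end (same import assumption): whatever the input length, every chunk that chunk_text returns should tokenize to at most MAX_MODEL_TOKENS, since oversized chunks are caught by the final filtering pass.

# Hypothetical end-to-end check that chunked output is model-safe.
from app import chunk_text, tokenizer, MAX_MODEL_TOKENS

text = "First clause, second clause; third clause. " * 300
for chunk in chunk_text(text):
    n = len(tokenizer(chunk, add_special_tokens=False)["input_ids"])
    assert n <= MAX_MODEL_TOKENS
print("all chunks fit the model window")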
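For context, summarizer is a transformers summarization pipeline built earlier in app.py, outside these hunks; a typical construction would look like the sketch below. model_name and the device logic are assumptions, with bart-large-cnn as a stand-in consistent with the 1024-token cap.

# Sketch only: the real model_name is defined earlier in app.py (not shown in this diff).
import torch
from transformers import AutoTokenizer, pipeline

model_name = "facebook/bart-large-cnn"  # assumed stand-in
tokenizer = AutoTokenizer.from_pretrained(model_name)
summarizer = pipeline(
    "summarization",
    model=model_name,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)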
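The merge step relies on chunk_map being a flat list of content_ids, one entry per submitted chunk, parallel to the pipeline's output list. In miniature, with toy data:

# Toy illustration of the setdefault merge: two chunks from one document,
# one from another, regrouped by content_id.
chunk_map = ["doc-1", "doc-1", "doc-2"]
summaries = [
    {"summary_text": "part a"},
    {"summary_text": "part b"},
    {"summary_text": "other doc"},
]

summary_map = {}
for content_id, result in zip(chunk_map, summaries):
    summary_map.setdefault(content_id, []).append(result["summary_text"])

assert summary_map == {"doc-1": ["part a", "part b"], "doc-2": ["other doc"]}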