spacesedan committed
Commit a67ba36 · 1 Parent(s): 4992a8e

split those sentences

Files changed (1):
  app.py  +40 -10
app.py CHANGED
@@ -6,10 +6,11 @@ import logging
 import torch
 import nltk
 import os
-
+import re
 
 from nltk.tokenize import sent_tokenize
 
+# Configure NLTK to use preloaded data path
 nltk_data_path = os.getenv("NLTK_DATA", "/home/user/nltk_data")
 nltk.data.path.append(nltk_data_path)
 
@@ -28,7 +29,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 # Token constraints
 MAX_MODEL_TOKENS = 1024
-SAFE_CHUNK_SIZE = 650  # Lowered for extra safety
+SAFE_CHUNK_SIZE = 600  # Reduced to leave room for special tokens
 
 # Pydantic schemas
 class SummarizationItem(BaseModel):
@@ -45,10 +46,40 @@ class SummarizationResponseItem(BaseModel):
 class BatchSummarizationResponse(BaseModel):
     summaries: List[SummarizationResponseItem]
 
-# Sentence-based chunking using nltk
-def split_sentences(text: str) -> list[str]:
-    return sent_tokenize(text.strip())
+# Sentence splitter with fallback for long sentences
+def split_sentences(text: str, max_sentence_tokens: int = SAFE_CHUNK_SIZE) -> list[str]:
+    sentences = sent_tokenize(text.strip())
+    split_results = []
 
+    for sentence in sentences:
+        token_len = len(tokenizer.tokenize(sentence))
+        if token_len <= max_sentence_tokens:
+            split_results.append(sentence)
+        else:
+            # Fallback: split by commas/semicolons
+            sub_sentences = re.split(r'[;,:]\s+', sentence)
+            for sub in sub_sentences:
+                sub = sub.strip()
+                if not sub:
+                    continue
+                if len(tokenizer.tokenize(sub)) <= max_sentence_tokens:
+                    split_results.append(sub)
+                else:
+                    # Final fallback: hard-split by word
+                    words = sub.split()
+                    buffer = []
+                    for word in words:
+                        buffer.append(word)
+                        current = " ".join(buffer)
+                        if len(tokenizer.tokenize(current)) > max_sentence_tokens:
+                            split_results.append(" ".join(buffer[:-1]))
+                            buffer = [word]
+                    if buffer:
+                        split_results.append(" ".join(buffer))
+
+    return split_results
+
+# Chunking based on token length
 def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
     sentences = split_sentences(text)
     chunks = []
@@ -56,7 +87,7 @@ def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
 
     for sentence in sentences:
         tentative_chunk = " ".join(current_chunk_sentences + [sentence])
-        token_count = len(tokenizer.encode(tentative_chunk, add_special_tokens=False))
+        token_count = len(tokenizer.tokenize(tentative_chunk))
 
         if token_count <= max_tokens:
             current_chunk_sentences.append(sentence)
@@ -68,12 +99,11 @@ def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
     if current_chunk_sentences:
         chunks.append(" ".join(current_chunk_sentences))
 
-    # Final filter: ensure nothing slipped through
+    # Final model-safe filtering
    final_chunks = []
    for chunk in chunks:
        encoded = tokenizer(chunk, return_tensors="pt", truncation=False, add_special_tokens=False)
        token_len = encoded["input_ids"].shape[1]
-
        if token_len <= MAX_MODEL_TOKENS:
            final_chunks.append(chunk)
        else:
@@ -98,7 +128,7 @@ async def summarize_batch(request: BatchSummarizationRequest):
         logger.error("No valid chunks after filtering. Returning empty response.")
         return {"summaries": []}
 
-    # Batch inference (safe, since we're now filtering properly)
+    # Inference
     summaries = summarizer(
         all_chunks,
         max_length=150,
@@ -108,7 +138,7 @@ async def summarize_batch(request: BatchSummarizationRequest):
         batch_size=4
     )
 
-    # Combine summaries by content_id
+    # Merge summaries by content_id
     summary_map = {}
     for content_id, result in zip(chunk_map, summaries):
         summary_map.setdefault(content_id, []).append(result["summary_text"])
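
For reference, a minimal, self-contained sketch of the word-level hard-split fallback this commit adds to split_sentences. It uses a plain whitespace word count as a stand-in for len(tokenizer.tokenize(...)) so it runs without loading the model; count_tokens and hard_split are illustrative names, not part of app.py.

def count_tokens(text: str) -> int:
    # Stand-in for len(tokenizer.tokenize(text)); assumes roughly one token per word.
    return len(text.split())


def hard_split(sub: str, max_sentence_tokens: int) -> list[str]:
    # Mirrors the final fallback in split_sentences: grow a buffer word by word;
    # once adding a word exceeds the limit, emit the words before it and start
    # a new buffer from that word.
    pieces: list[str] = []
    buffer: list[str] = []
    for word in sub.split():
        buffer.append(word)
        if count_tokens(" ".join(buffer)) > max_sentence_tokens:
            pieces.append(" ".join(buffer[:-1]))
            buffer = [word]
    if buffer:
        pieces.append(" ".join(buffer))
    return pieces


if __name__ == "__main__":
    run_on = " ".join(f"word{i}" for i in range(25))
    for piece in hard_split(run_on, max_sentence_tokens=10):
        print(count_tokens(piece), piece)

With a limit of 10, the 25-word sentence comes out as pieces of 10, 10, and 5 words, so no single run-on sentence can push a chunk past the model window on its own.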