spacesedan committed on
Commit 372b4a1 · 1 Parent(s): cba823e
Files changed (1)
  1. app.py +17 -8
app.py CHANGED
```diff
@@ -4,7 +4,14 @@ from transformers import pipeline, AutoTokenizer
 from typing import List
 import logging
 import torch
-import re
+import nltk
+import os
+
+from nltk.tokenize import sent_tokenize
+
+# Download punkt tokenizer if not already present
+nltk_data_path = os.getenv("NLTK_DATA", "/home/user/nltk_data")
+nltk.download("punkt", download_dir=nltk_data_path)
 
 app = FastAPI()
 
@@ -21,7 +28,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 # Token constraints
 MAX_MODEL_TOKENS = 1024
-SAFE_CHUNK_SIZE = 700
+SAFE_CHUNK_SIZE = 650  # Lowered for extra safety
 
 # Pydantic schemas
 class SummarizationItem(BaseModel):
@@ -38,9 +45,9 @@ class SummarizationResponseItem(BaseModel):
 class BatchSummarizationResponse(BaseModel):
     summaries: List[SummarizationResponseItem]
 
-# Sentence-based chunking
+# Sentence-based chunking using nltk
 def split_sentences(text: str) -> list[str]:
-    return re.split(r'(?<=[.!?])\s+', text.strip())
+    return sent_tokenize(text.strip())
 
 def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
     sentences = split_sentences(text)
@@ -49,7 +56,7 @@ def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
 
     for sentence in sentences:
         tentative_chunk = " ".join(current_chunk_sentences + [sentence])
-        token_count = len(tokenizer.encode(tentative_chunk, truncation=False))
+        token_count = len(tokenizer.encode(tentative_chunk, add_special_tokens=False))
 
         if token_count <= max_tokens:
             current_chunk_sentences.append(sentence)
@@ -64,16 +71,16 @@ def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
     # Final filter: ensure nothing slipped through
     final_chunks = []
     for chunk in chunks:
-        encoded = tokenizer(chunk, return_tensors="pt", truncation=False)
+        encoded = tokenizer(chunk, return_tensors="pt", truncation=False, add_special_tokens=False)
         token_len = encoded["input_ids"].shape[1]
+
         if token_len <= MAX_MODEL_TOKENS:
             final_chunks.append(chunk)
         else:
-            logger.warning(f"[CHUNKING] Dropped oversized chunk: {token_len} tokens")
+            logger.warning(f"[CHUNKING] Dropped oversized chunk ({token_len} tokens): {chunk[:100]}...")
 
     return final_chunks
 
-# Summarization endpoint
 @app.post("/summarize", response_model=BatchSummarizationResponse)
 async def summarize_batch(request: BatchSummarizationRequest):
     all_chunks = []
@@ -91,6 +98,7 @@ async def summarize_batch(request: BatchSummarizationRequest):
         logger.error("No valid chunks after filtering. Returning empty response.")
         return {"summaries": []}
 
+    # Batch inference (safe, since we're now filtering properly)
    summaries = summarizer(
         all_chunks,
         max_length=150,
@@ -100,6 +108,7 @@ async def summarize_batch(request: BatchSummarizationRequest):
         batch_size=4
     )
 
+    # Combine summaries by content_id
     summary_map = {}
     for content_id, result in zip(chunk_map, summaries):
         summary_map.setdefault(content_id, []).append(result["summary_text"])
```
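For reference, below is a minimal standalone sketch of the chunking path this commit introduces. The middle of chunk_text (how a full chunk is flushed and the next one started) is not visible in the diff, so that branch is an assumed greedy-packing pattern, and the checkpoint name is a stand-in for the app's actual model_name, which is defined outside the changed lines.

```python
import os

import nltk
from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer

# Same punkt bootstrap as in the commit; download path mirrors the Space's NLTK_DATA default.
nltk_data_path = os.getenv("NLTK_DATA", "/home/user/nltk_data")
nltk.download("punkt", download_dir=nltk_data_path)
nltk.data.path.append(nltk_data_path)

# Stand-in checkpoint: the app's real model_name is not part of this diff.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

MAX_MODEL_TOKENS = 1024
SAFE_CHUNK_SIZE = 650


def split_sentences(text: str) -> list[str]:
    return sent_tokenize(text.strip())


def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> list[str]:
    """Greedily pack whole sentences into chunks of at most max_tokens tokens."""
    chunks: list[str] = []
    current: list[str] = []
    for sentence in split_sentences(text):
        tentative = " ".join(current + [sentence])
        if len(tokenizer.encode(tentative, add_special_tokens=False)) <= max_tokens:
            current.append(sentence)
        else:
            # Assumed flush-and-restart behavior; this branch is not shown in the diff.
            if current:
                chunks.append(" ".join(current))
            current = [sentence]
    if current:
        chunks.append(" ".join(current))

    # Final filter, as in the commit: drop anything still over the hard model limit.
    return [
        c for c in chunks
        if len(tokenizer.encode(c, add_special_tokens=False)) <= MAX_MODEL_TOKENS
    ]


if __name__ == "__main__":
    sample = "First sentence. Second one follows! Does a third fit? " * 200
    for i, chunk in enumerate(chunk_text(sample)):
        print(i, len(tokenizer.encode(chunk, add_special_tokens=False)))
```

Counting tokens with add_special_tokens=False in both the incremental check and the final filter keeps the two budgets comparable, and the lowered 650-token SAFE_CHUNK_SIZE leaves headroom under the 1024-token model limit for the special tokens the summarization pipeline adds at inference time.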