spacesedan committed · Commit 9e815e0 · 1 Parent(s): 45e1223
Files changed (2):
  1. app.py +43 -20
  2. requirements.txt +3 -1
app.py CHANGED

@@ -4,20 +4,31 @@ from transformers import pipeline, AutoTokenizer
 from typing import List
 import logging
 import torch
+import nltk
+from nltk.tokenize import sent_tokenize
 
+# FastAPI app init
 app = FastAPI()
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger("summarizer")
 
-# Faster and lighter summarization model
+# NLTK setup
+nltk.download("punkt")
+
+# Model config
 model_name = "sshleifer/distilbart-cnn-12-6"
 device = 0 if torch.cuda.is_available() else -1
 logger.info(f"Running summarizer on {'GPU' if device == 0 else 'CPU'}")
 summarizer = pipeline("summarization", model=model_name, device=device)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
+# Token limits
+MAX_MODEL_TOKENS = 1024
+SAFE_CHUNK_SIZE = 700  # Conservative chunk size to stay below 1024 after re-tokenization
+
+# Input/output schemas
 class SummarizationItem(BaseModel):
     content_id: str
     text: str
@@ -32,22 +43,38 @@ class SummarizationResponseItem(BaseModel):
 class BatchSummarizationResponse(BaseModel):
     summaries: List[SummarizationResponseItem]
 
-# Ensure no chunk ever exceeds model token limit
-MAX_MODEL_TOKENS = 1024
-SAFE_CHUNK_SIZE = 700
-
+# New safe chunking logic using NLTK
 def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
-    tokens = tokenizer.encode(text, truncation=False)
+    sentences = sent_tokenize(text)
     chunks = []
-
-    for i in range(0, len(tokens), max_tokens):
-        chunk_tokens = tokens[i:i + max_tokens]
-        chunk_tokens = chunk_tokens[:MAX_MODEL_TOKENS]
-        chunk = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
-        chunks.append(chunk)
-
-    return chunks
-
+    current_chunk = ""
+
+    for sentence in sentences:
+        temp_chunk = f"{current_chunk} {sentence}".strip()
+        token_count = len(tokenizer.encode(temp_chunk, truncation=False))
+
+        if token_count <= max_tokens:
+            current_chunk = temp_chunk
+        else:
+            if current_chunk:
+                chunks.append(current_chunk)
+            current_chunk = sentence
+
+    if current_chunk:
+        chunks.append(current_chunk)
+
+    final_chunks = []
+    for chunk in chunks:
+        encoded = tokenizer(chunk, return_tensors="pt", truncation=False)
+        actual_len = encoded["input_ids"].shape[1]
+        if actual_len <= MAX_MODEL_TOKENS:
+            final_chunks.append(chunk)
+        else:
+            logger.warning(f"[CHUNKING] Dropped chunk due to re-encoding overflow: {actual_len} tokens")
+
+    return final_chunks
+
+# Main summarization endpoint
 @app.post("/summarize", response_model=BatchSummarizationResponse)
 async def summarize_batch(request: BatchSummarizationRequest):
     all_chunks = []
@@ -57,12 +84,8 @@ async def summarize_batch(request: BatchSummarizationRequest):
         token_count = len(tokenizer.encode(item.text, truncation=False))
        chunks = chunk_text(item.text)
        logger.info(f"[CHUNKING] content_id={item.content_id} token_len={token_count} num_chunks={len(chunks)}")
+
        for chunk in chunks:
-            encoded = tokenizer(chunk, return_tensors="pt", truncation=False)
-            final_len = encoded["input_ids"].shape[1]
-            if final_len > MAX_MODEL_TOKENS:
-                logger.warning(f"[SKIP] content_id={item.content_id} chunk still too long after decode: {final_len} tokens")
-                continue
             all_chunks.append(chunk)
             chunk_map.append(item.content_id)
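For reference, a minimal, illustrative sketch (not part of this commit) of how the new sentence-based chunker can be exercised on its own. It assumes app.py is importable as a module named app, and the sample text is made up; note that importing the module downloads the punkt data and loads the model.

# Illustrative only: call chunk_text() directly and confirm that every chunk
# re-encodes below the model limit. The module name "app" and the sample text
# are assumptions, not part of the commit.
from app import chunk_text, tokenizer, MAX_MODEL_TOKENS

sample_text = "This is one short sentence. " * 2000  # long enough to force several chunks
chunks = chunk_text(sample_text)
for i, chunk in enumerate(chunks):
    n_tokens = len(tokenizer.encode(chunk, truncation=False))
    print(f"chunk {i}: {n_tokens} tokens")
    assert n_tokens <= MAX_MODEL_TOKENS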
 
requirements.txt CHANGED

@@ -1,4 +1,6 @@
 fastapi
 uvicorn[standard]
-torch
 transformers
+torch
+nltk
+pydantic
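With these dependencies installed and the app running (for example via uvicorn app:app), a hedged sketch of calling the /summarize endpoint looks like the following. The top-level "items" key is an assumption, since the BatchSummarizationRequest model is outside this diff; only content_id and text are confirmed by SummarizationItem.

# Illustrative client call, not part of the commit. The "items" field name is
# an assumption; adjust it to match the actual BatchSummarizationRequest schema.
import requests

payload = {
    "items": [
        {"content_id": "doc-1", "text": "Long article text to summarize..."}
    ]
}
resp = requests.post("http://localhost:8000/summarize", json=payload)
resp.raise_for_status()
print(resp.json())  # expected shape: {"summaries": [...]}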