Commit 372b4a1 · Parent(s): cba823e · "come on"
app.py CHANGED
@@ -4,7 +4,14 @@ from transformers import pipeline, AutoTokenizer
 from typing import List
 import logging
 import torch
-import
+import nltk
+import os
+
+from nltk.tokenize import sent_tokenize
+
+# Download punkt tokenizer if not already present
+nltk_data_path = os.getenv("NLTK_DATA", "/home/user/nltk_data")
+nltk.download("punkt", download_dir=nltk_data_path)
 
 app = FastAPI()
 
@@ -21,7 +28,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 # Token constraints
 MAX_MODEL_TOKENS = 1024
-SAFE_CHUNK_SIZE =
+SAFE_CHUNK_SIZE = 650  # Lowered for extra safety
 
 # Pydantic schemas
 class SummarizationItem(BaseModel):
@@ -38,9 +45,9 @@ class SummarizationResponseItem(BaseModel):
 class BatchSummarizationResponse(BaseModel):
     summaries: List[SummarizationResponseItem]
 
-# Sentence-based chunking
+# Sentence-based chunking using nltk
 def split_sentences(text: str) -> list[str]:
-    return
+    return sent_tokenize(text.strip())
 
 def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
     sentences = split_sentences(text)
@@ -49,7 +56,7 @@ def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
 
     for sentence in sentences:
         tentative_chunk = " ".join(current_chunk_sentences + [sentence])
-        token_count = len(tokenizer.encode(tentative_chunk,
+        token_count = len(tokenizer.encode(tentative_chunk, add_special_tokens=False))
 
         if token_count <= max_tokens:
             current_chunk_sentences.append(sentence)
@@ -64,16 +71,16 @@ def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
     # Final filter: ensure nothing slipped through
     final_chunks = []
     for chunk in chunks:
-        encoded = tokenizer(chunk, return_tensors="pt", truncation=False)
+        encoded = tokenizer(chunk, return_tensors="pt", truncation=False, add_special_tokens=False)
         token_len = encoded["input_ids"].shape[1]
+
         if token_len <= MAX_MODEL_TOKENS:
             final_chunks.append(chunk)
         else:
-            logger.warning(f"[CHUNKING] Dropped oversized chunk
+            logger.warning(f"[CHUNKING] Dropped oversized chunk ({token_len} tokens): {chunk[:100]}...")
 
     return final_chunks
 
-# Summarization endpoint
 @app.post("/summarize", response_model=BatchSummarizationResponse)
 async def summarize_batch(request: BatchSummarizationRequest):
     all_chunks = []
@@ -91,6 +98,7 @@ async def summarize_batch(request: BatchSummarizationRequest):
         logger.error("No valid chunks after filtering. Returning empty response.")
         return {"summaries": []}
 
+    # Batch inference (safe, since we're now filtering properly)
    summaries = summarizer(
         all_chunks,
         max_length=150,
@@ -100,6 +108,7 @@ async def summarize_batch(request: BatchSummarizationRequest):
         batch_size=4
     )
 
+    # Combine summaries by content_id
     summary_map = {}
     for content_id, result in zip(chunk_map, summaries):
         summary_map.setdefault(content_id, []).append(result["summary_text"])
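
One thing the new download step relies on is NLTK actually searching the custom directory. If the NLTK_DATA environment variable is not set in the Space, the fallback path used in app.py is not automatically on NLTK's search path, and sent_tokenize can raise a LookupError even after the download succeeds. The following is a minimal sketch, not part of the commit, that registers the directory explicitly; it assumes the same fallback path as app.py.

# Sketch (assumption, not part of the commit): make sure the custom punkt
# location is on NLTK's search path before sent_tokenize is called.
import os
import nltk
from nltk.tokenize import sent_tokenize

nltk_data_path = os.getenv("NLTK_DATA", "/home/user/nltk_data")  # same fallback as app.py

# nltk.data.path is the list of directories NLTK searches; the NLTK_DATA env
# var is only picked up automatically when it is set, so append the fallback too.
if nltk_data_path not in nltk.data.path:
    nltk.data.path.append(nltk_data_path)

nltk.download("punkt", download_dir=nltk_data_path, quiet=True)
# (newer NLTK releases may additionally require the "punkt_tab" resource)

print(sent_tokenize("Tokenizers split text. Summarizers compress it."))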
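
For reference, a hypothetical client call against the /summarize endpoint shown in the diff. The endpoint path and the "summaries" response key appear in the code; the request field names ("items", "content_id", "text") and the port are assumptions, since the request schema is not visible in these hunks.

# Hypothetical request sketch; field names and port are assumptions.
import requests

payload = {
    "items": [
        {"content_id": "doc-1", "text": "Long article text to be chunked and summarized..."},
    ]
}

# 7860 is the usual default port for a Space; adjust to the actual deployment.
resp = requests.post("http://localhost:7860/summarize", json=payload, timeout=120)
resp.raise_for_status()
for item in resp.json()["summaries"]:
    print(item)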